Source code for PARyOpt.evaluators.async_sbatch

"""
---
    Copyright (c) 2018 Baskar Ganapathysubramanian, Balaji Sesha Sarath Pokuri
    
    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
---
"""

## --- end license text --- ##
"""
SLURM scheduler asynchronous evaluator sub-class
"""
from datetime import datetime, timedelta
import re
import tempfile
import numpy as np
import os
import io
from .connection import Connection, Host
from typing import Callable, Union

from .async import AsyncFunctionEvaluator, EvaluationFailed, ValueNotReady, EvaluateAgain


# helper local parse result functions
def VALUE_FROM_FILE(filename):
    """Returns a lcl_parse_result callable that reads a single float from `filename` in the remote job directory."""
    def value_from_file(local_dir: str, remote_dir: str, conn: Connection, x: np.array) -> float:
        fl = io.BytesIO()
        conn.sftp().getfo(os.path.join(remote_dir, filename), fl)
        return float(fl.getvalue().decode(encoding='utf-8').strip())
    return value_from_file
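
# Usage note (illustrative): VALUE_FROM_FILE('result.txt') builds a closure
# suitable for the lcl_parse_result argument of AsyncSbatchEvaluator below.
# The evaluator invokes it as value_from_file(local_dir, remote_dir, conn, x);
# the closure fetches <remote_dir>/result.txt over SFTP and parses its contents
# as a float. The name 'result.txt' is only an example -- pass whatever file
# the job script actually writes its cost value to.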
class AsyncSbatchEvaluator(AsyncFunctionEvaluator):
    """
    Class for cost functions that are evaluated by launching a job on a remote machine running the SLURM
    job scheduler.

    :param host: Host object containing the credentials for the server to connect to
    :param job_generator: callable that sets up the run directory for a given x (e.g. by writing config files). \
        It will be passed two arguments: the job directory and the point to evaluate at (x).
    :param job_script: either a string (for a fixed job script) or a callable that returns the job script string. \
        In the latter case, job_script will be passed two arguments: the job directory and the point to \
        evaluate at (x).
    :param remote_parse_result: callable that returns the cost function evaluated at x. \
        It will be passed two arguments: the job directory and the point to evaluate at (x). \
        It will be called after the submitted job has terminated (gracefully or otherwise). \
        If the job did not terminate successfully or the result is otherwise unavailable, it should raise an \
        exception, which signals the optimization routine not to try this point again. \
        This function is executed *on the remote host*, which must therefore have a matching version of \
        Python installed, along with the dill module.
    :param lcl_parse_result: callable that returns the cost function evaluated at x. \
        It is passed four arguments: the local job directory, the remote job directory, the Connection \
        object to the remote, and x. It is executed on the local machine, so the remote does not need to \
        have Python installed.
    :param lcl_jobs_dir: optional base directory to generate jobs in - default is $PWD/opt_jobs.
    :param remote_jobs_dir: optional base directory to upload jobs to - default is $HOME/paryopt_jobs.
    :param squeue_update_rate: minimum time between squeue calls. Lower for better job latency, higher to be \
        more polite.
    :param required_fraction: fraction of points which must complete before continuing to the next iteration. \
        See AsyncFunctionEvaluator for more info and implementation.
    :param max_pending: maximum number of simultaneously queued jobs, defaults to 25. \
        See AsyncFunctionEvaluator for implementation.
    """

    def __init__(self, host: Host,
                 job_generator: Callable[[str, np.array], None],
                 job_script: Union[str, Callable[[str, np.array], str]],
                 lcl_parse_result: Callable[[str, str, Connection, np.array], float] = None,
                 remote_parse_result: Callable[[str, np.array], float] = None,
                 lcl_jobs_dir: str = os.path.join(os.getcwd(), 'opt_jobs'),
                 squeue_update_rate: timedelta = timedelta(seconds=30),
                 remote_jobs_dir: str = 'paryopt_jobs',
                 required_fraction=1.0,
                 max_pending=25):
        super().__init__(required_fraction=required_fraction, max_pending=max_pending)
        if (lcl_parse_result and remote_parse_result) or (not lcl_parse_result and not remote_parse_result):
            raise Exception('You must specify exactly one of lcl_parse_result or remote_parse_result.')

        self.host = host
        self.job_generator = job_generator
        self.job_script = job_script
        self.local_parse_result = lcl_parse_result
        self.remote_parse_result = remote_parse_result
        self.lcl_jobs_dir = lcl_jobs_dir
        self.remote_jobs_dir = remote_jobs_dir

        self.connection = Connection()
        self.connection.connect(self.host)

        self._last_squeue = ""
        self._last_squeue_upd_time = datetime.min
        self._squeue_upd_rate = squeue_update_rate

    def _update_squeue(self):
        # refresh the cached squeue output for the current user
        stdout, stderr = self.connection.exec_command('squeue -u `whoami`')
        self._last_squeue = stdout
        self._last_squeue_upd_time = datetime.now()

    def squeue(self):
        # rate-limited accessor: only re-run squeue if the cache is stale
        if (datetime.now() - self._last_squeue_upd_time) >= self._squeue_upd_rate:
            self._update_squeue()
        return self._last_squeue
    def start(self, x: np.array) -> (str, str, str, datetime):
        """
        Generate the job directory on the local machine, upload it, and submit the job. Returns the data
        associated with the job: local directory, remote directory, job id, and submit time.
        """
        if not os.path.exists(self.lcl_jobs_dir):
            os.mkdir(self.lcl_jobs_dir)

        prefix = 'job_' + ('_'.join([str(v) for v in x])) + '_'
        directory = tempfile.mkdtemp(prefix=prefix, dir=self.lcl_jobs_dir)
        self.job_generator(directory, x)

        # generate job script in directory/job.sh
        job_script_path = os.path.join(directory, 'job.sh')
        with open(job_script_path, 'w') as f:
            job_script = self.job_script
            if callable(job_script):
                job_script = job_script(directory, x)
            f.write(job_script)

        # copy it to the remote machine
        remote_dir = os.path.join(self.remote_jobs_dir, os.path.basename(os.path.normpath(directory)))
        self.connection.put_dir(directory, remote_dir)

        # run sbatch from inside the directory
        stdout, stderr = self.connection.exec_command('sbatch job.sh', cwd=remote_dir)

        # grab the job ID from the output (assuming it was submitted successfully)
        match = re.match(r'Submitted batch job (\d+)', stdout)
        if not match:
            raise RuntimeError('sbatch in ' + remote_dir + ' failed: ' + stderr)
        job_id = match.group(1)

        return directory, remote_dir, job_id, datetime.now()
    def check_for_result(self, x: np.array, data: (str, str, str, datetime)) -> \
            Union[ValueNotReady, EvaluateAgain, EvaluationFailed, float]:
        """
        Checks for a result: returns the parsed cost once the job has finished, and ValueNotReady while it
        is still queued or until squeue has had a chance to update.

        :param x: location of evaluation
        :param data: data related to the location: local directory, remote directory, job id, and submit time
        :return: one of ValueNotReady, EvaluateAgain, EvaluationFailed, or a float
        """
        lcl_dir, remote_dir, job_id, submit_time = data

        # if squeue hasn't had a chance to update since submission, wait until it does
        # (this effectively enforces a minimum time for jobs)
        if (datetime.now() - submit_time) < self._squeue_upd_rate:
            return ValueNotReady()

        queue = self.squeue()
        status = re.search(r'\s*' + re.escape(job_id) + r'\s+\S+\s+\S+\s+\S+\s+(\S+)\s+', queue)
        if status and status.group(1) != 'CG':
            print('Waiting for ' + lcl_dir + ' to complete (status: ' + status.group(1) + ')')
            return ValueNotReady()

        print(' ' + lcl_dir + ' completed')
        # it's either no longer in the queue or in the 'completing' (CG) state, so it's done
        try:
            if self.local_parse_result:
                return self.local_parse_result(lcl_dir, remote_dir, self.connection, x)
            elif self.remote_parse_result:
                return self.connection.call_on_remote(self.remote_parse_result, remote_dir, x,
                                                      remote_cwd=remote_dir)
        except Exception as err:
            if __debug__:
                raise
            else:
                return EvaluationFailed(err)
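
For orientation, a minimal usage sketch follows. It is not part of the module: my_job_generator, my_job_script, ./simulate, and result.txt are hypothetical names, and the Host constructor arguments are elided because its signature is defined in PARyOpt.evaluators.connection, not shown here.

    # Hypothetical usage sketch -- names below are illustrative, not part of this module.
    import os
    import numpy as np

    from PARyOpt.evaluators.connection import Host
    from PARyOpt.evaluators.async_sbatch import AsyncSbatchEvaluator, VALUE_FROM_FILE


    def my_job_generator(directory: str, x: np.array) -> None:
        # write whatever input files the simulation expects into the job directory
        with open(os.path.join(directory, 'input.txt'), 'w') as f:
            f.write(' '.join(str(v) for v in x))


    def my_job_script(directory: str, x: np.array) -> str:
        # return a SLURM batch script; './simulate' stands in for the user's
        # executable, which is expected to write its cost value to result.txt
        return ('#!/bin/bash\n'
                '#SBATCH --time=00:10:00\n'
                '#SBATCH --nodes=1\n'
                './simulate input.txt > result.txt\n')


    host = Host(...)  # fill in cluster credentials; see PARyOpt.evaluators.connection
    evaluator = AsyncSbatchEvaluator(
        host,
        job_generator=my_job_generator,
        job_script=my_job_script,
        lcl_parse_result=VALUE_FROM_FILE('result.txt'),
        required_fraction=0.75,  # proceed once 75% of the batch has finished
        max_pending=10)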