From 18e1326ba1c184857e08b49ea6db016f3f40cd4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Harald=20B=C3=B6geholz?= Date: Mon, 16 Sep 2024 17:08:04 +0200 Subject: [PATCH] Bugfix for jobs running on multiple nodes. Using BATCHHOST instead of NODELIST allows multi-node LLM jobs. Also fixing time parsing for job times longer than a day. Also fixing typo in cloud_interface.sh --- cloud_interface.sh | 4 ++-- scheduler.py | 26 ++++++++++++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cloud_interface.sh b/cloud_interface.sh index 34353fd..71cef53 100755 --- a/cloud_interface.sh +++ b/cloud_interface.sh @@ -28,7 +28,7 @@ then # Check if last_execution_time file exists if [ -f last_update ]; then # Read the last execution time from the file - last_update=$(cat .last_update) + last_update=$(cat last_update) else # If the file doesn't exist, initialize the last execution time to 0 last_update=0 @@ -202,4 +202,4 @@ else else printf "HTTP/1.1 503 Service Unavailable\r\nContent-Type: text/html; charset=UTF-8\r\nDate: $(date -R)\r\nServer: KISSKI\r\n\r\nConnection to model broke\r\n" fi -fi \ No newline at end of file +fi diff --git a/scheduler.py b/scheduler.py index a713366..27e0937 100755 --- a/scheduler.py +++ b/scheduler.py @@ -40,7 +40,7 @@ def get_squeue_status(): squeue_output = subprocess.run( [squeue_path, '--me', '-h', '--name=service-backend', - '--format="{\"JOBID\": \"%.18i\", \"STATE\": \"%.2t\", \"TIME\": \"%.10M\", \"TIME_LIMIT\": \"%.9l\", \"NODELIST\": \"%N\"}"'], + '--format="{\"JOBID\": \"%.18i\", \"STATE\": \"%.2t\", \"TIME\": \"%.10M\", \"TIME_LIMIT\": \"%.9l\", \"BATCHHOST\": \"%B\"}"'], stdout=subprocess.PIPE, stderr=subprocess.PIPE).stdout.decode('utf-8') lines = squeue_output.split("\n") lines = [" ".join(line.split()).strip("\"") for line in lines] @@ -57,12 +57,22 @@ def generate_random_port_number(excluded: set) -> int: def squeue_time_to_timedelta(time_str): - try: - minutes, seconds = map(int, time_str.split(':')) - except: - hours, minutes, seconds = map(int, time_str.split(':')) - return timedelta(hours=hours, minutes=minutes, seconds=seconds) - return timedelta(minutes=minutes, seconds=seconds) + if '-' in time_str: + days_part, time_part = time_str.split('-') + days = int(days_part) + # Process the rest as hours:minutes:seconds + hours, minutes, seconds = map(int, time_part.split(':')) + else: + days = 0 + time_parts = time_str.split(':') + if len(time_parts) == 2: # MM:SS + hours = 0 + minutes, seconds = map(int, time_parts) + elif len(time_parts) == 3: # HH:MM:SS + hours, minutes, seconds = map(int, time_parts) + else: + raise ValueError("Invalid time format: {time_str}") + return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) def test_readiness(host, port): @@ -112,7 +122,7 @@ def from_squeue(self, squeue_line): self.status = squeue_json["STATE"].strip() self.time = squeue_json["TIME"].strip() self.time_limit = squeue_json["TIME_LIMIT"].strip() - self.host = squeue_json["NODELIST"].strip() + self.host = squeue_json["BATCHHOST"].strip() def is_about_to_expire(self): #time = squeue_time_to_timedelta(self.time)