"""A class for the TAMS data as an SQL database using SQLAlchemy."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import sessionmaker
# Module-level logger, named after this module.
_logger: logging.Logger = logging.getLogger(name=__name__)
class Base(DeclarativeBase):
    """Declarative base shared by the ORM table models of this module."""

    # Base.metadata gathers every subclass table; SQLFile._init_db creates them.
class Trajectory(Base):
    """ORM model of the ``trajectories`` table, holding the active trajectories.

    Each row records a trajectory file together with its current status.
    """

    __tablename__ = "trajectories"

    # 1-based SQL primary key; SQLFile maps a TAMS id ``i`` to ``i + 1``.
    id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True)
    # Trajectory file of this trajectory (NOT NULL).
    traj_file: Mapped[str] = mapped_column(nullable=False)
    # Status string (NOT NULL); rows inserted without a status start as "idle".
    status: Mapped[str] = mapped_column(nullable=False, default="idle")
class ArchivedTrajectory(Base):
    """ORM model of the ``archived_trajectories`` table.

    Only the trajectory file is kept for archived trajectories; there is no
    status column.
    """

    __tablename__ = "archived_trajectories"

    # 1-based SQL primary key; SQLFile maps a TAMS id ``i`` to ``i + 1``.
    id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True)
    # Trajectory file of the archived trajectory (NOT NULL).
    traj_file: Mapped[str] = mapped_column(nullable=False)
class SplittingIterations(Base):
    """ORM model of the ``splitting_iterations`` table.

    One row is stored per splitting iteration. List-valued fields are kept as
    single space-separated strings (see ``SQLFile.add_splitting_data``).
    """

    __tablename__ = "splitting_iterations"

    # Auto-incremented SQL primary key.
    id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True)
    # Splitting iteration index.
    split_id: Mapped[int] = mapped_column(nullable=False)
    # Number of restarted trajectories.
    rst_traj_count: Mapped[int] = mapped_column(nullable=False)
    # Space-separated ids of the restarted trajectories.
    rst_traj_ids: Mapped[str] = mapped_column(nullable=False)
    # Space-separated ids of the trajectories used to restart.
    from_traj_ids: Mapped[str] = mapped_column(nullable=False)
    # Space-separated minimum values.
    min_vals: Mapped[str] = mapped_column(nullable=False)
    # Space-separated score minimum and maximum values.
    min_max: Mapped[str] = mapped_column(nullable=False)
# Status strings a row of the ``trajectories`` table can carry.
valid_statuses = [
    "locked",  # set by SQLFile.lock_trajectory
    "idle",  # column default; set by SQLFile.release_trajectory
    "completed",  # set by SQLFile.mark_trajectory_as_completed
]
class SQLFile:
    """An SQL file.

    Allows atomic access to an SQL database from all
    the workers.

    Note: TAMS works with Python indexing starting at 0,
    while SQL indexing starts at 1. Trajectory ID is
    updated accordingly when accessing/updating the DB.

    Attributes:
        _file_name : The file name
        _engine : The SQLAlchemy engine bound to the SQLite database
        _Session : The session factory bound to ``_engine``
    """

    def __init__(self, file_name: str, in_memory: bool = False, ro_mode: bool = False) -> None:
        """Initialize the file.

        Args:
            file_name : The file name
            in_memory: a bool to trigger in-memory creation
            ro_mode: a bool to trigger read-only access to the database;
                ignored when ``in_memory`` is True

        Raises:
            RuntimeError : If the DB schema could not be initialized
        """
        self._file_name = file_name
        if in_memory:
            # Shared-cache in-memory database addressed as an SQLite URI filename.
            url = f"sqlite:///{file_name}?mode=memory&cache=shared&uri=true"
        elif ro_mode:
            url = f"sqlite:///{file_name}?mode=ro&uri=true"
        else:
            url = f"sqlite:///{file_name}"
        self._engine = create_engine(url, echo=False)
        self._Session = sessionmaker(bind=self._engine)
        self._init_db()

    @staticmethod
    def _to_db_id(traj_id: int) -> int:
        """Convert a 0-based TAMS trajectory id into the 1-based SQL primary key."""
        return traj_id + 1

    def _init_db(self) -> None:
        """Create the tables of the file (existing tables are left untouched).

        Raises:
            RuntimeError : If the DB schema could not be created
        """
        try:
            Base.metadata.create_all(self._engine)
        except SQLAlchemyError as e:
            err_msg = "Failed to initialize DB schema"
            _logger.exception(err_msg)
            raise RuntimeError(err_msg) from e

    def add_trajectory(self, traj_file: str) -> None:
        """Add a new trajectory to the DB.

        The new trajectory gets the column default status ("idle").

        Args:
            traj_file : The trajectory file of that trajectory

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            session.add(Trajectory(traj_file=traj_file))
            session.commit()
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to add trajectory")
            raise
        finally:
            session.close()

    def update_trajectory_file(self, traj_id: int, traj_file: str) -> None:
        """Update a trajectory file in the DB.

        Args:
            traj_id : The trajectory id
            traj_file : The new trajectory file of that trajectory

        Raises:
            SQLAlchemyError if the DB could not be accessed, including
            NoResultFound if no trajectory with that id exists
        """
        session = self._Session()
        try:
            traj = session.query(Trajectory).filter(Trajectory.id == self._to_db_id(traj_id)).one()
            # Store the plain string value; assigning a column construct here
            # would make the flush fail.
            traj.traj_file = traj_file
            session.commit()
        except SQLAlchemyError:
            session.rollback()
            err_msg = f"Failed to update trajectory {traj_id}"
            _logger.exception(err_msg)
            raise
        finally:
            session.close()

    def lock_trajectory(self, traj_id: int, allow_completed_lock: bool = False) -> bool:
        """Set the status of a trajectory to "locked" if possible.

        The status check and the status change happen in a single UPDATE
        statement, so two workers cannot both lock the same trajectory.

        Args:
            traj_id : The trajectory id
            allow_completed_lock : Allow to lock a "completed" trajectory

        Return:
            True if the trajectory was successfully locked, False otherwise

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        db_id = self._to_db_id(traj_id)
        allowed_status = ["idle", "completed"] if allow_completed_lock else ["idle"]
        session = self._Session()
        try:
            # A SELECT followed by an UPDATE is not atomic on SQLite: the dialect
            # drops FOR UPDATE and pysqlite only opens a transaction at the first
            # DML statement. A conditional UPDATE checks and sets in one step.
            n_locked = (
                session.query(Trajectory)
                .filter(Trajectory.id == db_id, Trajectory.status.in_(allowed_status))
                .update({"status": "locked"}, synchronize_session=False)
            )
            session.commit()
            if n_locked:
                return True
            # Nothing updated: either the status forbids locking, or the row is missing.
            if session.get(Trajectory, db_id) is not None:
                return False
            err_msg = f"Trajectory {traj_id} does not exist"
            _logger.error(err_msg)
            raise ValueError(err_msg)
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to lock trajectory")
            raise
        finally:
            session.close()

    def mark_trajectory_as_completed(self, traj_id: int) -> None:
        """Set the status of a trajectory to "completed" if possible.

        Only a "locked" trajectory is changed; any other status is left as is
        and a warning is logged.

        Args:
            traj_id : The trajectory id

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            traj = session.query(Trajectory).filter(Trajectory.id == self._to_db_id(traj_id)).one_or_none()
            if traj is None:
                err_msg = f"Trajectory {traj_id} does not exist"
                _logger.error(err_msg)
                raise ValueError(err_msg)
            if traj.status == "locked":
                traj.status = "completed"
                session.commit()
            else:
                warn_msg = f"Attempting to mark completed Trajectory {traj_id} already in status {traj.status}."
                _logger.warning(warn_msg)
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to mark trajectory as completed")
            raise
        finally:
            session.close()

    def release_trajectory(self, traj_id: int) -> None:
        """Set the status of a trajectory to "idle" if possible.

        Only a "locked" trajectory is changed; any other status is left as is
        and a warning is logged.

        Args:
            traj_id : The trajectory id

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            traj = session.query(Trajectory).filter(Trajectory.id == self._to_db_id(traj_id)).one_or_none()
            if traj is None:
                err_msg = f"Trajectory {traj_id} does not exist"
                _logger.error(err_msg)
                raise ValueError(err_msg)
            if traj.status == "locked":
                traj.status = "idle"
                session.commit()
            else:
                warn_msg = f"Attempting to release Trajectory {traj_id} already in status {traj.status}."
                _logger.warning(warn_msg)
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to release trajectory")
            raise
        finally:
            session.close()

    def get_trajectory_count(self) -> int:
        """Get the number of trajectories in the DB.

        Returns:
            The number of trajectories

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            return session.query(Trajectory).count()
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to count the number of trajectories")
            raise
        finally:
            session.close()

    def fetch_trajectory(self, traj_id: int) -> str:
        """Get the trajectory file of a trajectory.

        Args:
            traj_id : The trajectory id

        Return:
            The trajectory file

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            traj = session.query(Trajectory).filter(Trajectory.id == self._to_db_id(traj_id)).one_or_none()
            if traj is not None:
                tfile: str = traj.traj_file
                return tfile
            err_msg = f"Trajectory {traj_id} does not exist"
            _logger.error(err_msg)
            raise ValueError(err_msg)
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to fetch trajectory")
            raise
        finally:
            session.close()

    def release_all_trajectories(self) -> None:
        """Release all trajectories in the DB (set every status to "idle").

        Best effort: a DB error is logged and rolled back but not re-raised.
        """
        session = self._Session()
        try:
            session.query(Trajectory).update({"status": "idle"})
            session.commit()
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to release all trajectories")
        finally:
            session.close()

    def archive_trajectory(self, traj_file: str) -> None:
        """Add a new trajectory to the archive container.

        Best effort: a DB error is logged and rolled back but not re-raised.

        Args:
            traj_file : The trajectory file of that trajectory
        """
        session = self._Session()
        try:
            session.add(ArchivedTrajectory(traj_file=traj_file))
            session.commit()
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to archive trajectory")
        finally:
            session.close()

    def fetch_archived_trajectory(self, traj_id: int) -> str:
        """Get the trajectory file of a trajectory in the archive.

        Args:
            traj_id : The trajectory id

        Return:
            The trajectory file

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            traj = (
                session.query(ArchivedTrajectory)
                .filter(ArchivedTrajectory.id == self._to_db_id(traj_id))
                .one_or_none()
            )
            if traj is not None:
                tfile: str = traj.traj_file
                return tfile
            err_msg = f"Trajectory {traj_id} does not exist"
            _logger.error(err_msg)
            raise ValueError(err_msg)
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to fetch archived trajectory")
            raise
        finally:
            session.close()

    def get_archived_trajectory_count(self) -> int:
        """Get the number of trajectories in the archive.

        Returns:
            The number of trajectories

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            return session.query(ArchivedTrajectory).count()
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to count the number of archived trajectories")
            raise
        finally:
            session.close()

    def add_splitting_data(
        self, k: int, n_rst: int, rst_ids: list[int], from_ids: list[int], min_vals: list[float], min_max: list[float]
    ) -> None:
        """Add a new splitting data to the DB.

        List arguments are stored as space-separated strings.

        Args:
            k : The splitting iteration index
            n_rst : The number of restarted trajectories
            rst_ids : The list of restarted trajectory ids
            from_ids : The list of trajectories used to restart
            min_vals : The list of minimum values
            min_max : The score minimum and maximum values

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        session = self._Session()
        try:
            new_split = SplittingIterations(
                split_id=k,
                rst_traj_count=n_rst,
                rst_traj_ids=" ".join(str(x) for x in rst_ids),
                from_traj_ids=" ".join(str(x) for x in from_ids),
                min_vals=" ".join(str(x) for x in min_vals),
                min_max=" ".join(str(x) for x in min_max),
            )
            session.add(new_split)
            session.commit()
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to add splitting data")
            raise
        finally:
            session.close()

    def dump_file_json(self) -> None:
        """Dump the content of the DB tables to a json file.

        The file is named ``<stem of the DB file name>.json`` and is written
        relative to the current working directory. Trajectory and archived
        trajectory keys are converted back to 0-based ids; splitting data is
        keyed by its SQL id as stored.

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        db_data = {}
        session = self._Session()
        try:
            db_data["trajectories"] = {
                traj.id - 1: {"file": traj.traj_file, "status": traj.status} for traj in session.query(Trajectory).all()
            }
            db_data["archived_trajectories"] = {
                traj.id - 1: {"file": traj.traj_file} for traj in session.query(ArchivedTrajectory).all()
            }
            db_data["splitting_data"] = {
                split.id: {
                    "k": str(split.split_id),
                    "min_max_start": split.min_max,
                    "n_rst": str(split.rst_traj_count),
                    "rst_ids": split.rst_traj_ids,
                    "from_ids": split.from_traj_ids,
                    "min_vals": split.min_vals,
                }
                for split in session.query(SplittingIterations).all()
            }
        except SQLAlchemyError:
            session.rollback()
            _logger.exception("Failed to query the content of the DB")
            raise
        finally:
            session.close()
        json_file = Path(f"{Path(self._file_name).stem}.json")
        with json_file.open("w") as f:
            json.dump(db_data, f, indent=2)