"""A class for the TAMS data as an SQL database using SQLAlchemy."""from__future__importannotationsimportjsonimportloggingfrompathlibimportPathfromsqlalchemyimportcreate_enginefromsqlalchemy.excimportSQLAlchemyErrorfromsqlalchemy.ormimportDeclarativeBasefromsqlalchemy.ormimportMappedfromsqlalchemy.ormimportmapped_columnfromsqlalchemy.ormimportsessionmaker_logger=logging.getLogger(__name__)
class Base(DeclarativeBase):
    """Common declarative base that every table of the TAMS DB derives from."""
class Trajectory(Base):
    """A table storing the trajectories.

    Columns:
        id : SQL primary key, 1-based (TAMS trajectory id + 1)
        traj_file : The trajectory file
        status : One of "idle", "locked" or "completed"
    """

    __tablename__ = "trajectories"

    # Integer primary key: SQLite numbers inserted rows 1, 2, ..., which the
    # SQLFile accessors translate to/from 0-based TAMS trajectory ids.
    id: Mapped[int] = mapped_column(primary_key=True)
    traj_file: Mapped[str] = mapped_column(nullable=False)
    # Rows are inserted with only traj_file set, and lock_trajectory only
    # locks "idle" (optionally "completed") rows, so new rows must be "idle".
    status: Mapped[str] = mapped_column(default="idle", nullable=False)


# NOTE(review): ArchivedTrajectory is used by the SQLFile archive accessors but
# was not defined in the reviewed source; the table name below is inferred —
# confirm it against existing DB files before relying on it.
class ArchivedTrajectory(Base):
    """A table storing the archived trajectories.

    Columns:
        id : SQL primary key, 1-based (archive index + 1)
        traj_file : The trajectory file
    """

    __tablename__ = "archived_trajectories"

    id: Mapped[int] = mapped_column(primary_key=True)
    traj_file: Mapped[str] = mapped_column(nullable=False)
class SQLFile:
    """An SQL file.

    Allows atomic access to an SQL database from all
    the workers.

    Note: TAMS works with Python indexing starting at 0,
    while SQL indexing starts at 1. Trajectory ID is
    updated accordingly when accessing/updating the DB.

    Attributes:
        _file_name : The file name
    """

    def __init__(self, file_name: str) -> None:
        """Initialize the file.

        Binds an SQLite engine to ``file_name`` and creates the tables.

        Args:
            file_name : The file name

        Raises:
            RuntimeError : If the DB schema could not be initialized
        """
        self._file_name = file_name
        self._engine = create_engine(f"sqlite:///{file_name}", echo=False)
        # Session factory: every accessor opens its own short-lived session.
        self._Session = sessionmaker(bind=self._engine)
        self._init_db()

    def _init_db(self) -> None:
        """Create the tables of the file (create_all skips existing tables).

        Raises:
            RuntimeError : If the schema could not be created
        """
        try:
            Base.metadata.create_all(self._engine)
        except SQLAlchemyError as e:
            err_msg = "Failed to initialize DB schema"
            _logger.exception(err_msg)
            raise RuntimeError(err_msg) from e

    @staticmethod
    def _to_db_id(traj_id: int) -> int:
        """Convert a 0-based TAMS trajectory id to the 1-based SQL id."""
        return traj_id + 1

    @staticmethod
    def _missing_trajectory(traj_id: int) -> ValueError:
        """Log and return the error for a trajectory id that has no DB row."""
        err_msg = f"Trajectory {traj_id} does not exist"
        _logger.error(err_msg)
        return ValueError(err_msg)

    def _count_rows(self, model: type[Base], fail_msg: str) -> int:
        """Return the number of rows in ``model``'s table.

        Raises:
            SQLAlchemyError if the DB could not be accessed (``fail_msg`` is logged)
        """
        with self._Session() as session:
            try:
                return session.query(model).count()
            except SQLAlchemyError:
                session.rollback()
                _logger.exception(fail_msg)
                raise

    def _fetch_traj_file(self, model: type[Base], traj_id: int, fail_msg: str) -> str:
        """Return the ``traj_file`` of row ``traj_id`` (0-based) in ``model``'s table.

        Raises:
            ValueError if no row has the given id
            SQLAlchemyError if the DB could not be accessed (``fail_msg`` is logged)
        """
        with self._Session() as session:
            try:
                traj = session.query(model).filter(model.id == self._to_db_id(traj_id)).one_or_none()
                if traj is None:
                    raise self._missing_trajectory(traj_id)
                tfile: str = traj.traj_file
                return tfile
            except SQLAlchemyError:
                session.rollback()
                _logger.exception(fail_msg)
                raise

    def _transition_from_locked(self, traj_id: int, new_status: str, action: str, fail_msg: str) -> None:
        """Move a "locked" trajectory to ``new_status``; only warn if it is not locked.

        Args:
            traj_id : The trajectory id
            new_status : The status to set
            action : Verb phrase used in the warning message
            fail_msg : Message logged if the DB access fails

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        with self._Session() as session:
            try:
                traj = session.query(Trajectory).filter(Trajectory.id == self._to_db_id(traj_id)).one_or_none()
                if traj is None:
                    raise self._missing_trajectory(traj_id)
                if traj.status == "locked":
                    traj.status = new_status
                    session.commit()
                else:
                    warn_msg = f"Attempting to {action} Trajectory {traj_id} already in status {traj.status}."
                    _logger.warning(warn_msg)
            except SQLAlchemyError:
                session.rollback()
                _logger.exception(fail_msg)
                raise

    def add_trajectory(self, traj_file: str) -> None:
        """Add a new trajectory to the DB.

        Args:
            traj_file : The trajectory file of that trajectory

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        with self._Session() as session:
            try:
                session.add(Trajectory(traj_file=traj_file))
                session.commit()
            except SQLAlchemyError:
                session.rollback()
                _logger.exception("Failed to add trajectory")
                raise

    def update_trajectory_file(self, traj_id: int, traj_file: str) -> None:
        """Update a trajectory file in the DB.

        Args:
            traj_id : The trajectory id
            traj_file : The new trajectory file of that trajectory

        Raises:
            SQLAlchemyError if the DB could not be accessed; this includes
            NoResultFound (raised by ``.one()``) if the trajectory does not exist
        """
        with self._Session() as session:
            try:
                traj = session.query(Trajectory).filter(Trajectory.id == self._to_db_id(traj_id)).one()
                # Store the plain string. The previous code wrapped it in
                # mapped_column(), assigning a schema construct rather than
                # the file name.
                traj.traj_file = traj_file
                session.commit()
            except SQLAlchemyError:
                session.rollback()
                err_msg = f"Failed to update trajectory {traj_id}"
                _logger.exception(err_msg)
                raise

    def lock_trajectory(self, traj_id: int, allow_completed_lock: bool = False) -> bool:
        """Set the status of a trajectory to "locked" if possible.

        The status check and the status change are done by a single
        conditional UPDATE, so two workers cannot both lock the same
        trajectory.

        Args:
            traj_id : The trajectory id
            allow_completed_lock : Allow to lock a "completed" trajectory

        Return:
            True if the trajectory was successfully locked, False otherwise

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        db_id = self._to_db_id(traj_id)
        allowed_status = ["idle", "completed"] if allow_completed_lock else ["idle"]
        with self._Session() as session:
            try:
                # A SELECT followed by an UPDATE is not atomic on SQLite: the
                # SQLite dialect emits no FOR UPDATE (with_for_update() is a
                # no-op) and pysqlite only begins a transaction at the first
                # DML statement. Filtering on status inside the UPDATE itself
                # makes check-and-set a single atomic write.
                n_locked = (
                    session.query(Trajectory)
                    .filter(Trajectory.id == db_id, Trajectory.status.in_(allowed_status))
                    .update({"status": "locked"}, synchronize_session=False)
                )
                session.commit()
                if n_locked > 0:
                    return True
                # Nothing updated: either the status forbids locking, or the
                # trajectory does not exist at all.
                found = session.query(Trajectory.id).filter(Trajectory.id == db_id).first() is not None
            except SQLAlchemyError:
                session.rollback()
                _logger.exception("Failed to lock trajectory")
                raise
        if found:
            return False
        raise self._missing_trajectory(traj_id)

    def mark_trajectory_as_completed(self, traj_id: int) -> None:
        """Set the status of a trajectory to "completed" if possible.

        Only a "locked" trajectory is changed; otherwise a warning is logged.

        Args:
            traj_id : The trajectory id

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        self._transition_from_locked(traj_id, "completed", "mark completed", "Failed to mark trajectory as completed")

    def release_trajectory(self, traj_id: int) -> None:
        """Set the status of a trajectory to "idle" if possible.

        Only a "locked" trajectory is changed; otherwise a warning is logged.

        Args:
            traj_id : The trajectory id

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        self._transition_from_locked(traj_id, "idle", "release", "Failed to release trajectory")

    def get_trajectory_count(self) -> int:
        """Get the number of trajectories in the DB.

        Returns:
            The number of trajectories

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        return self._count_rows(Trajectory, "Failed to count the number of trajectories")

    def fetch_trajectory(self, traj_id: int) -> str:
        """Get the trajectory file of a trajectory.

        Args:
            traj_id : The trajectory id

        Return:
            The trajectory file

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        return self._fetch_traj_file(Trajectory, traj_id, "Failed to fetch trajectory")

    def release_all_trajectories(self) -> None:
        """Set the status of every trajectory in the DB to "idle".

        DB errors are logged and not re-raised.
        """
        with self._Session() as session:
            try:
                session.query(Trajectory).update({"status": "idle"})
                session.commit()
            except SQLAlchemyError:
                session.rollback()
                # NOTE(review): not re-raised, unlike most accessors — confirm
                # this best-effort behavior is intended.
                _logger.exception("Failed to release all trajectories")

    def archive_trajectory(self, traj_file: str) -> None:
        """Add a new trajectory to the archive container.

        DB errors are logged and not re-raised.

        Args:
            traj_file : The trajectory file of that trajectory
        """
        with self._Session() as session:
            try:
                session.add(ArchivedTrajectory(traj_file=traj_file))
                session.commit()
            except SQLAlchemyError:
                session.rollback()
                # NOTE(review): not re-raised, so a failed archive is only
                # visible in the log — confirm this is intended.
                _logger.exception("Failed to archive trajectory")

    def fetch_archived_trajectory(self, traj_id: int) -> str:
        """Get the trajectory file of a trajectory in the archive.

        Args:
            traj_id : The trajectory id

        Return:
            The trajectory file

        Raises:
            ValueError if the trajectory with the given id does not exist
            SQLAlchemyError if the DB could not be accessed
        """
        return self._fetch_traj_file(ArchivedTrajectory, traj_id, "Failed to fetch archived trajectory")

    def get_archived_trajectory_count(self) -> int:
        """Get the number of trajectories in the archive.

        Returns:
            The number of trajectories

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        return self._count_rows(ArchivedTrajectory, "Failed to count the number of archived trajectories")

    def dump_file_json(self) -> None:
        """Dump the content of the trajectory and archive tables to a JSON file.

        The file is named ``<DB file stem>.json``. Being a relative path, it is
        written to the current working directory. Entries are keyed by 0-based
        TAMS id (json serializes these int keys as strings).

        Raises:
            SQLAlchemyError if the DB could not be accessed
        """
        with self._Session() as session:
            try:
                db_data = {
                    "trajectories": {
                        traj.id - 1: {"file": traj.traj_file, "status": traj.status}
                        for traj in session.query(Trajectory).all()
                    },
                    "archived_trajectories": {
                        traj.id - 1: {"file": traj.traj_file} for traj in session.query(ArchivedTrajectory).all()
                    },
                }
            except SQLAlchemyError:
                session.rollback()
                _logger.exception("Failed to dump the DB content to JSON")
                raise

        json_file = Path(f"{Path(self._file_name).stem}.json")
        with json_file.open("w") as f:
            json.dump(db_data, f, indent=2)