"""A class for the TAMS data as an SQL database using SQLAlchemy."""
from __future__ import annotations
import json
import logging
import pickle
from pathlib import Path
from typing import Any
from typing import cast
import numpy as np
from sqlalchemy import Boolean
from sqlalchemy import CursorResult
from sqlalchemy import Float
from sqlalchemy import LargeBinary
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from pytams.sqlmanager import BaseSQLManager
_logger = logging.getLogger(__name__)
class DiagBase(DeclarativeBase):
    """Declarative base class shared by the diagnostics tables.

    Its ``metadata`` is handed to :class:`BaseSQLManager` in
    :class:`DiagDB` so the schema can be created on connection.
    """
class DiagnosticEntry(DiagBase):
    """Table for recording model data at specific score levels.

    Each row stores a (pickled) model snapshot for one trajectory,
    together with the score level and trajectory time at which it was
    captured, the trajectory weight in the ensemble, and an
    active/inactive flag toggled when trajectories are discarded
    during TAMS branching.
    """

    __tablename__ = "diagnostics"

    # Auto-incremented surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # ID of the trajectory this snapshot belongs to.
    traj_id: Mapped[int] = mapped_column(nullable=False)
    # Score level whose crossing triggered the snapshot.
    level_crossed: Mapped[float] = mapped_column(Float, nullable=False)
    # Trajectory time at which the snapshot was taken.
    time: Mapped[float] = mapped_column(Float, nullable=False)
    # Weight of the trajectory in the ensemble at insertion time.
    weight: Mapped[float] = mapped_column(Float, nullable=False)
    # Pickled model data blob (unpickled in DiagDB.get_diagnostic_data).
    model_data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
    # False once the owning trajectory has been discarded.
    active: Mapped[bool] = mapped_column(Boolean, nullable=False)
    # Label of the diagnostic that produced the entry.
    diaglabel: Mapped[str] = mapped_column(nullable=False)
class DiagDB(BaseSQLManager):
    """A database to keep track of the diagnostics data.

    Diagnostic entries are aggregated in a single table. Each entry
    is associated to a trajectory and its weight in the ensemble.
    """

    def __init__(self, file_name: str, in_memory: bool = False, ro_mode: bool = False) -> None:
        """Initialize the file.

        Args:
            file_name: The file name
            in_memory: a bool to trigger in-memory creation
            ro_mode: a bool to trigger read-only access to the database
        """
        super().__init__(file_name, DiagBase.metadata, in_memory, ro_mode)

    def add_diagnostic_entry(
        self,
        diaglabel: str,
        traj_id: int,
        level: float,
        time: float,
        weight: float,
        ldata: bytes,
    ) -> None:
        """Atomic insert of a diagnostic snapshot.

        The data schema assumes that any new addition to the database is
        made on an active trajectory, so the entry is inserted with
        ``active=True``.

        Args:
            diaglabel: the label of the diagnostic inserting the entry
            traj_id: the ID of the traj adding the entry
            level: the score level of the entry
            time: the trajectory time at which the diagnostic was triggered
            weight: the weight of the trajectory
            ldata: the actual model data stored in the database
        """
        with self.session_scope() as session:
            entry = DiagnosticEntry(
                diaglabel=diaglabel,
                traj_id=traj_id,
                level_crossed=level,
                time=time,
                weight=weight,
                active=True,
                model_data=ldata,
            )
            session.add(entry)

    def get_highest_recorded_level(self, traj_id: int, label: str) -> float:
        """Return the maximum level already recorded for this traj/label.

        Args:
            traj_id: the ID of a trajectory
            label: the label of the diagnostic targeter

        Returns:
            the highest value of level_crossed, or -inf when no entry
            exists yet for this trajectory/label pair
        """
        with self.session_scope() as session:
            stmt = (
                select(func.max(DiagnosticEntry.level_crossed))
                .where(DiagnosticEntry.traj_id == traj_id)
                .where(DiagnosticEntry.diaglabel == label)
            )
            result = session.scalar(stmt)
            # MAX over an empty selection yields NULL/None -> map to -inf
            # so callers can compare levels unconditionally.
            return float(result) if result is not None else -np.inf

    def duplicate_diagnostic_history(
        self,
        ancestor_id: int,
        discarded_id: int,
        new_id: int,
        new_weight: float,
        threshold: float,
    ) -> int:
        """Copy diagnostic entries from an ancestor to a descendant.

        Copies all entries where level_crossed <= threshold.
        The entries belonging to the discarded trajectory are set
        to inactive.

        Args:
            ancestor_id: the ID of the ancestor to copy data from
            discarded_id: the ID of the discarded trajectory (during TAMS iterations)
            new_id: the ID of the new child trajectory
            new_weight: the weight of the new child trajectory
            threshold: the score threshold up to which copy must be performed

        Returns:
            the number of entries duplicated
        """
        with self.session_scope() as session:
            # Set the discarded trajectory to inactive
            stmt_update = (
                update(DiagnosticEntry)
                .where(
                    DiagnosticEntry.traj_id == discarded_id,
                )
                .values(active=False)
            )
            session.execute(stmt_update)

            # Select the relevant entries from the ancestor
            stmt = select(DiagnosticEntry).where(
                DiagnosticEntry.traj_id == ancestor_id, DiagnosticEntry.level_crossed <= threshold
            )
            ancestor_entries = session.execute(stmt).scalars().all()
            if not ancestor_entries:
                return 0

            # Re-create each entry for the new child trajectory, stripping
            # the original primary key 'id' so a fresh one is assigned.
            new_entries = [
                DiagnosticEntry(
                    diaglabel=entry.diaglabel,
                    traj_id=new_id,
                    level_crossed=entry.level_crossed,
                    time=entry.time,
                    weight=new_weight,
                    active=True,
                    model_data=entry.model_data,
                )
                for entry in ancestor_entries
            ]
            session.add_all(new_entries)
            return len(new_entries)

    def update_all_active_weights(self, new_weight: float) -> int:
        """Update the weight of all active diagnostic entries.

        Args:
            new_weight: the updated weight

        Returns:
            the number of entries updated (rowcount of the UPDATE)
        """
        with self.session_scope() as session:
            stmt_update = (
                update(DiagnosticEntry)
                .where(
                    DiagnosticEntry.active,
                )
                .values(weight=new_weight)
            )
            result = session.execute(stmt_update)
            # rowcount may be None for some drivers -> normalize to 0.
            return int(cast("CursorResult", result).rowcount or 0)

    def get_diagnostic_data(self, label: str) -> dict[float, list[tuple[Any, float]]]:
        """Retrieve all diagnostic snapshots for a specific label.

        Args:
            label: the label of the diagnostic of interest

        Returns:
            A dictionary mapping each iso-level (float) to a list of tuples.
            Each tuple contains (unpickled_data, trajectory_weight).
        """
        results_dict: dict[float, list[tuple[Any, float]]] = {}
        with self.session_scope() as session:
            # Query entries for the specific label, ordered by level
            stmt = (
                select(DiagnosticEntry.level_crossed, DiagnosticEntry.weight, DiagnosticEntry.model_data)
                .where(DiagnosticEntry.diaglabel == label)
                .order_by(DiagnosticEntry.level_crossed.asc())
            )
            rows = session.execute(stmt).all()

            for level, weight, blob in rows:
                # Unpickle the model data. NOTE: pickle is only safe here
                # because the blobs were written by this application.
                data = pickle.loads(blob)  # noqa: S301
                results_dict.setdefault(level, []).append((data, weight))

        return results_dict

    def dump_to_json(self, json_path: str) -> None:
        """Export the entire diagnostic database to a JSON file.

        Note that the content of the data stored in the database is
        omitted. Only the metadata of each stored entry is dumped for
        debugging purposes.

        Args:
            json_path: path of the JSON file to write
        """
        dump_data = []
        with self.session_scope() as session:
            # Fetch every entry in the database
            stmt = select(DiagnosticEntry).order_by(DiagnosticEntry.traj_id, DiagnosticEntry.level_crossed)
            results = session.execute(stmt).scalars().all()

            for entry in results:
                # Prepare the row dictionary (model_data deliberately omitted)
                row = {
                    "diaglabel": entry.diaglabel,
                    "traj_id": entry.traj_id,
                    "level_crossed": float(entry.level_crossed),
                    "time": float(entry.time),
                    "weight": float(entry.weight),
                    "active": entry.active,
                }
                dump_data.append(row)

        # Write to file with pretty printing
        with Path(json_path).open("w") as f:
            json.dump(dump_data, f, indent=4)

    def close(self) -> None:
        """Dispose of the engine and clear connections."""
        if self._engine:
            self._engine.dispose()