pyrevs.diagnostics.diagnostic

Diagnostic class for pyREVS.

Classes

DiagnosticPlugin

A base class for diagnostic plugins.

FirstTimeCrossingDiagnostic

Triggers when the score function crosses pre-defined levels.

DiagnosticAnalyst

A class to handle analysing diagnostic statistics.

Functions

diagnosticfactory(→ list[DiagnosticPlugin])

Parse input parameters to generate a list of DiagnosticPlugin.

Module Contents

class DiagnosticPlugin(dlabel: str, params: dict[Any, Any], tid: int, weight: float, workdir: pathlib.Path, fprocess: collections.abc.Callable[Ellipsis, Any], ddb: pyrevs.diagnostics.diagdb.DiagDB)

A base class for diagnostic plugins.

Plugins are attached to the trajectory objects.

Variables:
  • _label – the diagnostic label

  • _tid – the ID of the trajectory the plugin is attached to

  • _weight – the weight of the trajectory

abstractmethod get_crossed_levels(new_snapshot: pyrevs.core.Snapshot) → list[float]

Test whether a diagnostic is needed.

update(old_snapshot: pyrevs.core.Snapshot, new_snapshot: pyrevs.core.Snapshot) → None

Standard entry point called after every MCMC step.

class FirstTimeCrossingDiagnostic(dlabel: str, params: dict[Any, Any], tid: int, tweight: float, workdir: pathlib.Path, fprocess: collections.abc.Callable[Ellipsis, Any], ddb: pyrevs.diagnostics.diagdb.DiagDB)

Bases: DiagnosticPlugin

Triggers when the score function crosses pre-defined levels.

This is a central diagnostic plugin of pyREVS. It is triggered only the first time the score function crosses each pre-defined level.

This makes it possible to evaluate the probability of crossing any intermediate score level in a pyREVS run, and to estimate the mean first passage time.

Variables:
  • _levels – the threshold levels

  • _highest_recorded_score – the high water mark of the plugin

  • _checked_db – True if the DB has been checked

get_crossed_levels(new_snapshot: pyrevs.core.Snapshot) → list[float]

Get the list of levels crossed during the last step.

Parameters:

new_snapshot – the new (end of the time step) snapshot

Returns:

the list of score levels crossed
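The first-crossing behaviour described above can be illustrated with a minimal, self-contained sketch. It does not use pyREVS itself: the class FirstCrossingTracker, its attributes, and the convention that a level counts as crossed once the score is greater than or equal to it are all assumptions made for illustration. It works on plain float scores rather than Snapshot objects.

```python
class FirstCrossingTracker:
    """Hypothetical stand-in for first-crossing logic; not pyREVS code."""

    def __init__(self, levels: list[float]) -> None:
        self.levels = sorted(levels)
        # Highest score seen so far (the "high water mark").
        self.highest_recorded_score = float("-inf")

    def get_crossed_levels(self, new_score: float) -> list[float]:
        """Return the levels reached for the first time by new_score."""
        crossed = [
            lvl
            for lvl in self.levels
            # Assumed convention: a level is crossed when score >= level.
            if self.highest_recorded_score < lvl <= new_score
        ]
        # Levels below the high water mark are never reported again.
        self.highest_recorded_score = max(self.highest_recorded_score, new_score)
        return crossed


tracker = FirstCrossingTracker([0.2, 0.5, 0.8])
first = tracker.get_crossed_levels(0.6)   # reaches 0.2 and 0.5
second = tracker.get_crossed_levels(0.3)  # score dropped, nothing new
third = tracker.get_crossed_levels(0.9)   # reaches 0.8 only
```

Keeping a single high water mark means a trajectory that dips below a level and climbs back does not trigger that level a second time.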

diagnosticfactory(configs: dict[str, pyrevs.core.Config], tid: int, tweight: float, workdir: pathlib.Path, fprocess: collections.abc.Callable[Ellipsis, Any], ddb: pyrevs.diagnostics.diagdb.DiagDB) → list[DiagnosticPlugin]

Parse input parameters to generate a list of DiagnosticPlugin.

Parameters:
  • configs – a dict with a Config object for each diagnostic

  • tid – the ID of the trajectory the diagnostic is attached to

  • tweight – the weight of the trajectory

  • workdir – the workdir associated with a trajectory

  • fprocess – the forward model diagnostic function

  • ddb – the diagnostic database to add the data to

class DiagnosticAnalyst(db_path: str)

A class to handle analysing diagnostic statistics.

The analysis logic is kept separate from the gathering logic: this class retrieves data from the diagnostic database and performs some computations (mostly conditional statistics on score iso-levels).

db

get_diagnostic_data(label: str) → dict[float, list[tuple[Any, float, float, int]]]

User-facing access to the diagnostic DB.

An alias to the DB access for the analyst.

Returns:

A dictionary mapping each score iso-level (float) to a list of tuples. Each tuple contains (unpickled_data, trajectory_weight, time, tid).
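To make the documented return shape concrete, the sketch below walks over a hand-written dictionary with the same layout (level mapped to a list of (data, weight, time, tid) tuples). The mock values, the "energy" payload, and the helper summarize are hypothetical; no pyREVS call is made.

```python
from typing import Any

# Mock data in the documented shape: level -> [(data, weight, time, tid), ...]
diag_data: dict[float, list[tuple[Any, float, float, int]]] = {
    0.2: [({"energy": 1.0}, 0.5, 3.0, 0), ({"energy": 2.0}, 0.5, 4.0, 1)],
    0.5: [({"energy": 3.0}, 0.5, 7.5, 1)],
}


def summarize(
    data: dict[float, list[tuple[Any, float, float, int]]],
) -> dict[float, dict[str, Any]]:
    """Hypothetical helper: count records and collect trajectory IDs per level."""
    summary: dict[float, dict[str, Any]] = {}
    for level, records in sorted(data.items()):
        summary[level] = {
            "n_records": len(records),
            "tids": sorted({tid for _, _, _, tid in records}),
            "total_weight": sum(weight for _, weight, _, _ in records),
        }
    return summary


summary = summarize(diag_data)
```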

compute_weighted_stats(label: str) → dict[float, dict[str, Any]]

Aggregate data and compute mean/variance per level.
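As an illustration of what a per-level weighted aggregation could look like, here is a minimal sketch using weight-normalized mean and variance. The estimator, the scalar payloads, and the output keys "mean" and "variance" are assumptions; the actual computation and key names used by compute_weighted_stats are not specified here.

```python
def weighted_mean_var(values: list[float], weights: list[float]) -> tuple[float, float]:
    """Weight-normalized mean and (biased) variance of scalar values."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    mean = sum(w * v for v, w in zip(values, weights)) / total
    var = sum(w * (v - mean) ** 2 for v, w in zip(values, weights)) / total
    return mean, var


# Mock records: level -> [(scalar_data, weight, time, tid), ...]
records_by_level = {
    0.2: [(1.0, 0.25, 3.0, 0), (3.0, 0.75, 4.0, 1)],
    0.5: [(2.0, 1.0, 7.5, 1)],
}

stats = {}
for level, records in records_by_level.items():
    values = [rec[0] for rec in records]
    weights = [rec[1] for rec in records]
    mean, var = weighted_mean_var(values, weights)
    # Hypothetical output keys, chosen for this sketch only.
    stats[level] = {"mean": mean, "variance": var}
```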

get_conditional_means(label: str) → dict[float, dict[str, Any]]

Compute means at level L_i conditioned on reaching or not reaching L_{i+1}.

Returns:

{level: {"reached_next": mean_val, "failed_next": mean_val}}
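The conditioning on reaching L_{i+1} can be sketched as follows. This is not the pyREVS implementation: it assumes scalar payloads, assumes that a trajectory "reached" the next level if its tid has a record at that level, and leaves the last level out because it has no successor. Only the keys "reached_next" and "failed_next" come from the documented return value.

```python
from typing import Optional


def weighted_mean(pairs: list[tuple[float, float]]) -> Optional[float]:
    """Weighted mean of (value, weight) pairs, or None if there are none."""
    total = sum(w for _, w in pairs)
    if total == 0:
        return None
    return sum(v * w for v, w in pairs) / total


def conditional_means(data: dict[float, list[tuple[float, float, float, int]]]):
    """Means at each level, split by whether the trajectory reached the next one."""
    levels = sorted(data)
    result = {}
    for level, next_level in zip(levels, levels[1:]):
        # Assumption: a tid recorded at the next level means it was reached.
        next_tids = {tid for _, _, _, tid in data[next_level]}
        reached = [(v, w) for v, w, _, tid in data[level] if tid in next_tids]
        failed = [(v, w) for v, w, _, tid in data[level] if tid not in next_tids]
        result[level] = {
            "reached_next": weighted_mean(reached),
            "failed_next": weighted_mean(failed),
        }
    return result


# Mock records: level -> [(scalar_data, weight, time, tid), ...]
mock = {
    0.2: [(1.0, 0.5, 3.0, 0), (3.0, 0.5, 4.0, 1), (5.0, 0.5, 4.5, 2)],
    0.5: [(2.0, 0.5, 7.5, 1), (4.0, 0.5, 8.0, 2)],
}
result = conditional_means(mock)
```

In this mock run, trajectories 1 and 2 reach level 0.5 while trajectory 0 does not, so the two means at level 0.2 are computed over disjoint groups.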