"""A set of utility functions for TAMS."""importloggingfromtypingimportAnyimportnumpyasnpimportnumpy.typingasnpt
def setup_logger(params: dict[Any, Any]) -> None:
    """Set up the root logger from the TAMS parameters.

    The level is read from ``params["tams"]["loglevel"]`` (case-insensitive,
    ``"INFO"`` when absent). When ``params["tams"]["logfile"]`` holds a
    non-empty value, a file handler writing to that path is added to the root
    logger with the same level and format.

    Args:
        params: a dictionary of parameters, with a ``"tams"`` sub-dictionary

    Raises:
        ValueError: if the requested level is not one of DEBUG, INFO,
            WARNING, ERROR or CRITICAL
    """
    tams_params = params["tams"]

    # Set logging level
    known_levels = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    log_level_str = str(tams_params.get("loglevel", "INFO")).upper()
    # Validate before touching any logging state so that a bad value fails
    # with a clear message instead of an unbound local further down.
    if log_level_str not in known_levels:
        err_msg = (
            f"Unknown log level '{log_level_str}', "
            f"expected one of {', '.join(known_levels)}"
        )
        raise ValueError(err_msg)
    log_level = known_levels[log_level_str]

    log_format = "[%(levelname)s] %(asctime)s - %(message)s"

    # Set root logger
    # NOTE: basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(
        level=log_level,
        format=log_format,
    )

    # Add file handler to root logger
    log_path = tams_params.get("logfile", None)
    if log_path:
        log_file = logging.FileHandler(log_path)
        log_file.setLevel(log_level)
        log_file.setFormatter(logging.Formatter(log_format))
        logging.getLogger("").addHandler(log_file)
def get_min_scored(maxes: npt.NDArray[Any], nworkers: int) -> tuple[list[int], npt.NDArray[Any]]:
    """Select the lowest-scored trajectories, extended by any ties.

    Trajectories are visited in ascending order of their score. The first
    ``nworkers`` of them are always selected; after that, a trajectory is
    still selected when its score equals that of the most recently selected
    one, so more than ``nworkers`` indices can be returned.

    Args:
        maxes: array of maxima across all trajectories
        nworkers: number of workers

    Returns:
        list of indices of the selected trajectories, in ascending score order
        array of the scores of the selected trajectories
    """
    selected: list[int] = []
    for traj_id in np.argsort(maxes):
        # A tie can only exist once at least one trajectory has been selected.
        tied_with_last = maxes[traj_id] == maxes[selected[-1]] if selected else False
        if len(selected) < nworkers or tied_with_last:
            selected.append(traj_id)

    return selected, maxes[selected]
def moving_avg(arr_in: npt.NDArray[Any], window_l: int) -> npt.NDArray[Any]:
    """Return the moving average of a 1D numpy array.

    For element ``i`` the window nominally covers the half-open index range
    ``[i - ceil(window_l / 2), i + floor(window_l / 2))``, so for odd lengths
    it leans toward earlier samples. Near either end of the array the window
    is shifted inward so that it always holds ``window_l`` samples. A window
    longer than the array is reduced to the full array.

    Args:
        arr_in: 1D numpy array
        window_l: length of the moving average window, at least 1

    Returns:
        1D numpy array with the same length as ``arr_in``

    Raises:
        ValueError: if ``window_l`` is smaller than 1
    """
    if window_l < 1:
        err_msg = f"window_l must be at least 1, got {window_l}"
        raise ValueError(err_msg)

    n_pts = len(arr_in)
    arr_out = np.zeros(arr_in.shape[0])

    # A window longer than the data can only ever cover the whole array.
    width = min(window_l, n_pts)
    back_reach = int(np.ceil(window_l / 2))

    for i in range(n_pts):
        # Clamp the window start so the slice stays inside the array on both
        # sides; the exclusive end may reach n_pts, so the last sample is used.
        start = min(max(i - back_reach, 0), n_pts - width)
        arr_out[i] = np.mean(arr_in[start : start + width])

    return arr_out