Coverage for lisacattools/plugins/mbh.py: 75%

191 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-13 12:12 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2021 - James I. Thorpe, Tyson B. Littenberg, Jean-Christophe 

3# Malapert 

4# 

5# This file is part of lisacattools. 

6# 

7# lisacattools is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lisacattools is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lisacattools. If not, see <https://www.gnu.org/licenses/>. 

19"""Module implemented the MBH catalog.""" 

20import glob 

21import logging 

22import os 

23from itertools import chain 

24from typing import List 

25from typing import Optional 

26from typing import Union 

27 

28import numpy as np 

29import pandas as pd 

30 

31from ..catalog import GWCatalog 

32from ..catalog import GWCatalogs 

33from ..catalog import UtilsLogs 

34from ..catalog import UtilsMonitoring 

35 

36UtilsLogs.addLoggingLevel("TRACE", 15) 

37 

38 

class MbhCatalogs(GWCatalogs):
    """Implementation of the MBH catalogs."""

    # kwargs key naming extra directories to scan in addition to `path`.
    EXTRA_DIR = "extra_directories"

    def __init__(
        self,
        path: str,
        accepted_pattern: Optional[str] = "MBH_wk*C.h5",
        rejected_pattern: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Init the MbhCatalogs by reading all catalogs with a specific
        pattern in a given directory and rejecting files by another pattern.

        The list of catalogs is sorted by "observation week".

        Args:
            path (str): directory
            accepted_pattern (str, optional): pattern to accept files.
                Defaults to "MBH_wk*C.h5".
            rejected_pattern (str, optional): pattern to reject files.
                Defaults to None.

        Raises:
            ValueError: no files found matching the accepted and rejected
                patterns.
        """
        self.path = path
        self.accepted_pattern = accepted_pattern
        self.rejected_pattern = rejected_pattern
        self.extra_directories = kwargs.get(MbhCatalogs.EXTRA_DIR, list())
        directories = self._search_directories(
            self.path, self.extra_directories
        )
        self.cat_files = self._search_files(
            directories, accepted_pattern, rejected_pattern
        )
        if len(self.cat_files) == 0:
            # Implicit f-string concatenation (no backslash continuations)
            # so the message does not embed runs of indentation whitespace.
            raise ValueError(
                f"no files found matching the accepted "
                f"({self.accepted_pattern}) and rejected "
                f"({self.rejected_pattern}) patterns in {directories}"
            )
        self.__metadata = pd.concat(
            [self._read_cats(cat_file) for cat_file in self.cat_files]
        )
        self.__metadata = self.__metadata.sort_values(by="observation week")

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_directories(
        self, path: str, extra_directories: List[str]
    ) -> List[str]:
        """Compute the list of directories on which the pattern will be
        applied.

        Args:
            path (str): main path
            extra_directories (List[str]): other directories

        Returns:
            List[str]: list of directories on which the pattern will be
            applied
        """
        # Copy so the caller's list is not mutated by the append below.
        directories: List[str] = extra_directories[:]
        directories.append(path)
        return directories

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _search_files(
        self, directories: List[str], accepted_pattern, rejected_pattern
    ) -> List[str]:
        """Search files in directories according to a set of constraints:
        accepted and rejected patterns.

        Args:
            directories (List[str]): List of directories to scan
            accepted_pattern ([type]): pattern to get files
            rejected_pattern ([type]): pattern to reject files

        Returns:
            List[str]: List of files
        """
        accepted_files = list(
            chain.from_iterable(
                glob.glob(os.path.join(path, accepted_pattern))
                for path in directories
            )
        )
        if rejected_pattern is None:
            rejected_files: List[str] = list()
        else:
            rejected_files = list(
                chain.from_iterable(
                    glob.glob(os.path.join(path, rejected_pattern))
                    for path in directories
                )
            )
        # Set difference removes rejected files and any duplicates; order
        # is therefore unspecified.
        return list(set(accepted_files) - set(rejected_files))

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _read_cats(self, cat_file: str) -> pd.DataFrame:
        """Reads the metadata of a given catalog and the location of the
        file.

        Args:
            cat_file (str): catalog to load

        Returns:
            pd.DataFrame: pandas data frame
        """
        df = pd.read_hdf(cat_file, key="metadata")
        # Remember where each catalog came from so it can be reopened later.
        df["location"] = cat_file
        return df

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def metadata(self) -> pd.DataFrame:
        __doc__ = GWCatalogs.metadata.__doc__  # noqa: F841
        return self.__metadata

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def count(self) -> int:
        __doc__ = GWCatalogs.count.__doc__  # noqa: F841
        return len(self.metadata.index)

    @property
    @UtilsMonitoring.io(level=logging.TRACE)
    def files(self) -> List[str]:
        __doc__ = GWCatalogs.files.__doc__  # noqa: F841
        return self.cat_files

    @UtilsMonitoring.io(level=logging.TRACE)
    def get_catalogs_name(self) -> List[str]:
        __doc__ = GWCatalogs.get_catalogs_name.__doc__  # noqa: F841
        return list(self.metadata.index)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_first_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_first_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[0]["location"]
        name = self.metadata.index[0]
        return MbhCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_last_catalog(self) -> GWCatalog:
        __doc__ = GWCatalogs.get_last_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[self.count - 1]["location"]
        name = self.metadata.index[self.count - 1]
        return MbhCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog(self, idx: int) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog.__doc__  # noqa: F841
        location = self.metadata.iloc[idx]["location"]
        name = self.metadata.index[idx]
        return MbhCatalog(name, location)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=10)
    def get_catalog_by(self, name: str) -> GWCatalog:
        __doc__ = GWCatalogs.get_catalog_by.__doc__  # noqa: F841
        cat_idx = self.metadata.index.get_loc(name)
        return self.get_catalog(cat_idx)

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_lineage(self, cat_name: str, src_name: str) -> pd.DataFrame:
        __doc__ = GWCatalogs.get_lineage.__doc__  # noqa: F841

        dfs: List[pd.Series] = list()
        # Walk the parent chain: each catalog names its parent catalog and
        # each source names its parent source in the previous epoch.
        while src_name != "" and cat_name not in [None, ""]:
            detections = self.get_catalog_by(cat_name).get_dataset(
                "detections"
            )
            src = detections.loc[[src_name]]
            # Column capitalization differs between catalog versions;
            # narrow to KeyError so real failures still propagate.
            try:
                wk = self.metadata.loc[cat_name]["observation week"]
            except KeyError:
                wk = self.metadata.loc[cat_name]["Observation Week"]

            src.insert(0, "Observation Week", wk, True)
            src.insert(1, "Catalog", cat_name, True)
            dfs.append(src)
            try:
                prnt = self.metadata.loc[cat_name]["parent"]
            except KeyError:
                prnt = self.metadata.loc[cat_name]["Parent"]

            cat_name = prnt
            src_name = src.iloc[0]["Parent"]

        histDF: pd.DataFrame = pd.concat(dfs, axis=0)
        histDF.drop_duplicates(
            subset="Log Likelihood", keep="last", inplace=True
        )
        histDF.sort_values(by="Observation Week", ascending=True, inplace=True)
        return histDF

    @UtilsMonitoring.io(level=logging.TRACE)
    @UtilsMonitoring.time_spend(level=logging.DEBUG, threshold_in_ms=100)
    def get_lineage_data(self, lineage: pd.DataFrame) -> pd.DataFrame:
        __doc__ = GWCatalogs.get_lineage_data.__doc__  # noqa: F841

        def _process_lineage(source_epoch, source_data, obs_week):
            # Tag each epoch's posterior samples with their source name and
            # observation week before concatenation.
            source_data.insert(
                len(source_data.columns), "Source", source_epoch, True
            )
            source_data.insert(
                len(source_data.columns), "Observation Week", obs_week, True
            )
            return source_data

        source_epochs = list(lineage.index)

        merge_source_epochs: pd.DataFrame = pd.concat(
            [
                _process_lineage(
                    source_epoch,
                    self.get_catalog_by(
                        lineage.loc[source_epoch]["Catalog"]
                    ).get_source_samples(source_epoch),
                    lineage.loc[source_epoch]["Observation Week"],
                )
                for source_epoch in source_epochs
            ]
        )
        merge_source_epochs = merge_source_epochs[
            [
                "Source",
                "Observation Week",
                "Mass 1",
                "Mass 2",
                "Spin 1",
                "Spin 2",
                "Ecliptic Latitude",
                "Ecliptic Longitude",
                "Luminosity Distance",
                "Barycenter Merge Time",
                "Merger Phase",
                "Polarization",
                "cos inclination",
            ]
        ].copy()
        return merge_source_epochs

    def __repr__(self):
        # Implicit string concatenation avoids the whitespace a backslash
        # continuation would embed in the result.
        return (
            f"MbhCatalogs({self.path!r}, {self.accepted_pattern!r}, "
            f"{self.rejected_pattern!r}, {self.extra_directories!r})"
        )

    def __str__(self):
        return (
            f"MbhCatalogs: {self.path} {self.accepted_pattern!r} "
            f"{self.rejected_pattern!r} {self.extra_directories!r}"
        )

299 

300 

class MbhCatalog(GWCatalog):
    """Implementation of the Mbh catalog."""

    def __init__(self, name: str, location: str):
        """Init the MBH catalog with a name and a location.

        Args:
            name (str): name of the catalog
            location (str): location of the catalog
        """
        self.__name = name
        self.__location = location
        # Context manager guarantees the HDF5 handle is closed even if
        # reading the keys raises.
        with pd.HDFStore(location, "r") as store:
            self.__datasets = store.keys()

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def datasets(self):
        """dataset.

        :getter: Returns the list of datasets
        :type: List
        """
        return self.__datasets

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_dataset(self, name: str) -> pd.DataFrame:
        """Returns a dataset based on its name.

        Args:
            name (str): name of the dataset

        Returns:
            pd.DataFrame: the dataset
        """
        return pd.read_hdf(self.location, key=name)

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def name(self) -> str:
        __doc__ = GWCatalog.name.__doc__  # noqa: F841
        return self.__name

    @property
    @UtilsMonitoring.io(level=logging.DEBUG)
    def location(self) -> str:
        __doc__ = GWCatalog.location.__doc__  # noqa: F841
        return self.__location

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_detections(
        self, attr: Union[List[str], str] = None
    ) -> Union[List[str], pd.DataFrame, pd.Series]:
        __doc__ = GWCatalog.get_detections.__doc__  # noqa: F841
        detections = self.get_dataset("detections")
        # With no attribute requested, return only the source names.
        return (
            list(detections.index) if attr is None else detections[attr].copy()
        )

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_attr_detections(self) -> List[str]:
        __doc__ = GWCatalog.get_attr_detections.__doc__  # noqa: F841
        return list(self.get_dataset("detections").columns)

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_median_source(self, attr: str) -> pd.DataFrame:
        __doc__ = GWCatalog.get_median_source.__doc__  # noqa: F841
        detections: pd.Series = self.get_detections(attr)
        # Pick the source whose `attr` value is closest to the median.
        source_idx = self.get_detections()[
            np.argmin(np.abs(np.array(detections) - detections.median()))
        ]
        return self.get_detections(self.get_attr_detections()).loc[
            [source_idx]
        ]

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_source_samples(
        self, source_name: str, attr: List[str] = None
    ) -> pd.DataFrame:
        __doc__ = GWCatalog.get_source_samples.__doc__  # noqa: F841
        # Posterior samples are stored under the "<source>_chain" key.
        samples = self.get_dataset(f"{source_name}_chain")
        return samples if attr is None else samples[attr].copy()

    @UtilsMonitoring.io(level=logging.DEBUG)
    def get_attr_source_samples(self, source_name: str) -> List[str]:
        __doc__ = GWCatalog.get_attr_source_samples.__doc__  # noqa: F841
        return list(self.get_dataset(f"{source_name}_chain").columns)

    @UtilsMonitoring.io(level=logging.TRACE)
    def describe_source_samples(self, source_name: str) -> pd.DataFrame:
        __doc__ = GWCatalog.describe_source_samples.__doc__  # noqa: F841
        return self.get_source_samples(source_name).describe()

    def __repr__(self):
        return f"MbhCatalog({self.__name!r}, {self.__location!r})"

    def __str__(self):
        return f"MbhCatalog: {self.__name} {self.__location}"