Coverage for lisacattools/analyze.py: 34%

190 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-13 12:12 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2021 - James I. Thorpe, Tyson B. Littenberg, Jean-Christophe 

3# Malapert 

4# 

5# This file is part of lisacattools. 

6# 

7# lisacattools is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lisacattools is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lisacattools. If not, see <https://www.gnu.org/licenses/>. 

19import logging 

20import os 

21from typing import Dict 

22from typing import List 

23from typing import NoReturn 

24 

25import corner 

26import ligo.skymap.plot # noqa: F401 

27import matplotlib.pyplot as plt 

28import numpy as np 

29import pandas as pd 

30import seaborn as sns 

31 

32from .catalog import GWCatalog 

33from .catalog import GWCatalogs 

34from .custom_logging import UtilsLogs 

35from .monitoring import UtilsMonitoring 

36from .utils import FrameEnum 

37from .utils import HPhist 

38 

# Register an extra "TRACE" logging level with numeric value 15, i.e. between
# logging.DEBUG (10) and logging.INFO (20), through the project's UtilsLogs
# helper. This executes once, when this module is imported.
UtilsLogs.addLoggingLevel("TRACE", 15)

40 

41 

class LisaAnalyse:
    """Factory choosing the analysis class that matches a catalog type.

    A single ``GWCatalog`` is analysed by :class:`CatalogAnalysis`, while a
    collection of catalogs (``GWCatalogs``) is analysed over time by
    :class:`HistoryAnalysis`.
    """

    @staticmethod
    def create(catalog, save_dir=None):
        """Build the analysis object suited to ``catalog``.

        Args:
            catalog: a ``GWCatalog`` or a ``GWCatalogs`` instance
            save_dir (str, optional): directory where plots are saved,
                forwarded to the analysis constructor

        Returns:
            CatalogAnalysis or HistoryAnalysis

        Raises:
            NotImplementedError: if ``catalog`` is of any other type
        """
        if isinstance(catalog, GWCatalog):
            return CatalogAnalysis(catalog, save_dir)
        if isinstance(catalog, GWCatalogs):
            return HistoryAnalysis(catalog, save_dir)
        raise NotImplementedError(f"type {type(catalog)} not implemented")

56 

57 

class AbstractLisaAnalyze:
    """Abstract Object to link the two implementation and to share some
    method.

    Provides helpers used by both subclasses: keyword-argument lookup with a
    default value and a wrapper around ``corner.corner``.
    """

    def __init__(self):
        pass

    @UtilsMonitoring.io(level=logging.DEBUG)
    def _get_variable(
        self, dico: Dict, variable: str, default_val: object
    ) -> object:
        """Return ``dico[variable]`` when the key exists, else ``default_val``.

        Args:
            dico (Dict): mapping to look into (typically a ``**kwargs`` dict)
            variable (str): key to look up
            default_val (object): value returned when the key is absent

        Returns:
            object: the stored value or ``default_val``
        """
        return dico.get(variable, default_val)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_corners_ds(self, sources, *args, **kwargs):
        """Draw a corner plot of ``sources`` with ``corner.corner``.

        Recognised keyword arguments (defaults in brackets): color ["red"],
        plot_datapoints [False], fill_contours [True], bins [50],
        smooth [1.0], levels [[0.68, 0.95]], fontsize [16] (axis-label font
        size), fig [None], title ["parameters"]. Positional ``*args`` are
        accepted but ignored.

        When ``fig`` is given, the corner plot is drawn onto it and ``title``
        is not applied. Otherwise a new figure is created and ``title`` is set
        as its suptitle.

        Args:
            sources: samples to plot, passed unchanged to ``corner.corner``

        Returns:
            The figure returned by ``corner.corner`` (the one passed as
            ``fig``, or the newly created one).
        """
        color = self._get_variable(kwargs, "color", "red")
        plot_datapoints = self._get_variable(kwargs, "plot_datapoints", False)
        fill_contours = self._get_variable(kwargs, "fill_contours", True)
        bins = self._get_variable(kwargs, "bins", 50)
        smooth = self._get_variable(kwargs, "smooth", 1.0)
        levels = self._get_variable(kwargs, "levels", [0.68, 0.95])
        fontsize = self._get_variable(kwargs, "fontsize", 16)
        fig = self._get_variable(kwargs, "fig", None)
        title = self._get_variable(kwargs, "title", "parameters")

        # Options shared by both branches, built once instead of duplicating
        # the full corner.corner call.
        corner_options = dict(
            color=color,
            plot_datapoints=plot_datapoints,
            fill_contours=fill_contours,
            bins=bins,
            smooth=smooth,
            levels=levels,
            label_kwargs={"fontsize": fontsize},
        )
        if fig is not None:
            # Overlay onto the caller's figure; its title is left untouched.
            return corner.corner(sources, fig=fig, **corner_options)
        new_fig = corner.corner(sources, **corner_options)
        new_fig.suptitle(title)
        return new_fig

106 

107 

class CatalogAnalysis(AbstractLisaAnalyze):
    """Handle the analysis of one catalog."""

    def __init__(self, catalog: GWCatalog, save_img_dir=None):
        """Init the analysis with a Lisa catalog.

        Args:
            catalog (GWCatalog): catalog to analyse
            save_img_dir (str, optional): directory where plots are saved;
                when falsy (default None) plots are not written to disk
        """
        super().__init__()
        self.catalog = catalog
        self.save_img_dir = save_img_dir

    @property
    def catalog(self):
        """Catalog.

        :getter: Returns the catalog of this analysis
        :setter: Sets the catalog.
        :type: GWCatalog
        """
        return self._catalog

    @catalog.setter
    def catalog(self, value):
        self._catalog = value

    @property
    def save_img_dir(self):
        """Save image directory for plot.

        :getter: Returns the directory where plots are saved
        :setter: Sets the directory where plots are saved.
        :type: str
        """
        return self._save_img_dir

    @save_img_dir.setter
    def save_img_dir(self, value):
        self._save_img_dir = value

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_mbh_mergers_history(self) -> NoReturn:
        """Plot the cumulative count of observed mergers over time.

        Merge times come from the catalog's "Barycenter Merge Time" column.
        They are divided by 86400 to be plotted in days and drawn as a step
        curve that starts at (0, 0). Each merger is annotated with its source
        name (the index label). The figure is saved as
        ``MBH_mergers_<catalog name>.png`` in ``save_img_dir`` when set.
        """
        seconds_per_day = 86400
        # Sort a copy rather than in place, so the data returned by the
        # catalog is not mutated.
        merge_times = self.catalog.get_detections(
            "Barycenter Merge Time"
        ).sort_values(ascending=True)
        merge_days = np.array(merge_times) / seconds_per_day
        # Prepend t=0 so the step curve starts from a zero merger count.
        merge_t = np.insert(merge_days, 0, 0)
        merge_count = np.arange(0, len(merge_times) + 1)

        fig, ax = plt.subplots(figsize=[8, 6], dpi=100)
        ax.step(merge_t, merge_count, where="post")
        # Iterate labels and values together. The Series is indexed by source
        # name, so merge_times[m] would rely on pandas' deprecated
        # integer-position fallback.
        for m, (source_name, day) in enumerate(
            zip(merge_times.index, merge_days)
        ):
            ax.annotate(
                source_name,  # text of the label
                (day, merge_count[m]),  # point to label
                textcoords="offset points",  # offset the text from the point
                xytext=(2, 5),  # offset (x, y) in points
                rotation="horizontal",
                ha="left",
            )
        ax.set_xlabel("Observation Time [days]")
        ax.set_ylabel("Merger Count")
        ax.set_title(f"MBH Mergers in catalog {self.catalog.name}")
        ax.grid()
        if self.save_img_dir:
            fig.savefig(
                os.path.join(
                    self.save_img_dir,
                    "MBH_mergers_" + self.catalog.name + ".png",
                )
            )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_individual_sources(self) -> NoReturn:
        """Plot the 90% credible intervals of the component masses.

        For each detected source, the 5%, 50% and 95% quantiles of its
        "Mass 1" and "Mass 2" samples are drawn as an error bar centred on
        the median, on log-log axes. The figure is saved as
        ``component_masses<catalog name>.png`` in ``save_img_dir`` when set.
        """
        mass_columns = ["Mass 1", "Mass 2"]
        quantile_levels = [0.05, 0.5, 0.95]

        fig, ax = plt.subplots(figsize=[8, 6], dpi=100)
        detections = self.catalog.get_detections(mass_columns)
        for idx, source in enumerate(detections.index):
            chain = self.catalog.get_source_samples(source, mass_columns)
            l1, m1, h1 = np.quantile(
                np.array(chain["Mass 1"]), quantile_levels
            )
            l2, m2, h2 = np.quantile(
                np.array(chain["Mass 2"]), quantile_levels
            )
            # Switch marker after the first 10 sources, presumably so they
            # stay distinguishable once the default 10-colour cycle repeats.
            mkr = "o" if idx < 10 else "^"
            ax.errorbar(
                m1,
                m2,
                xerr=np.vstack((m1 - l1, h1 - m1)),
                yerr=np.vstack((m2 - l2, h2 - m2)),
                label=source,
                markersize=6,
                capsize=2,
                marker=mkr,
                markerfacecolor="none",
            )
        ax.set_xscale("log", nonpositive="clip")
        ax.set_yscale("log", nonpositive="clip")
        ax.grid()
        ax.set_xlabel("Mass 1 [MSun]")
        ax.set_ylabel("Mass 2 [MSun]")
        ax.set_title(f"90% CI for Component Masses in {self.catalog.name} ")
        ax.legend(loc="lower right")
        if self.save_img_dir:
            fig.savefig(
                os.path.join(
                    self.save_img_dir,
                    "component_masses" + self.catalog.name + ".png",
                )
            )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_corners(self, source_name, params, *args, **kwargs) -> NoReturn:
        """Corner plot of the samples of one source.

        Args:
            source_name (str): name of the source in the catalog
            params (List): sample parameters to include in the plot
            *args, **kwargs: forwarded to ``plot_corners_ds`` (keyword
                options such as color, bins, levels, fig, title)
        """
        samples = self.catalog.get_source_samples(source_name, params)
        # The samples themselves are the data to plot. Passing source_name
        # first would make it the plotted data and push the samples into the
        # ignored *args.
        self.plot_corners_ds(samples, *args, **kwargs)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_skymap(
        self, source, nside, system: FrameEnum = FrameEnum.ECLIPTIC
    ) -> NoReturn:
        """Plot a HEALPix sky map in a Mollweide projection.

        The figure is saved as ``skymap.png`` in ``save_img_dir`` when set.

        Args:
            source: data histogrammed by ``HPhist`` (presumably the source's
                sky-location samples; TODO confirm)
            nside: HEALPix resolution parameter passed to ``HPhist``
            system (FrameEnum, optional): coordinate reference frame.
                Defaults to ``FrameEnum.ECLIPTIC``.
        """
        hp_map = HPhist(source, nside, system)
        fig = plt.figure(figsize=(8, 6), dpi=100)
        ax = plt.axes(
            [0.05, 0.05, 0.9, 0.9], projection="geo degrees mollweide"
        )
        ax.grid()
        ax.imshow_hpx(hp_map, cmap="plasma")
        if self.save_img_dir:
            fig.savefig(os.path.join(self.save_img_dir, "skymap.png"))

245 

246 

class HistoryAnalysis(AbstractLisaAnalyze):
    """Analyse a particular source to see how its parameter estimates
    improve over time."""

    def __init__(self, catalogs: GWCatalogs, save_img_dir=None):
        """Init the HistoryAnalysis with all catalogs to load the parameter
        estimates over the time.

        Args:
            catalogs (GWCatalogs): the time-ordered catalogs to analyse
            save_img_dir (str, optional): directory where plots are saved;
                when falsy (default None) plots are not written to disk
        """
        super().__init__()
        self.catalogs = catalogs
        self.save_img_dir = save_img_dir

    @property
    def catalogs(self):
        """Catalogs.

        :getter: Returns the catalogs of this analysis
        :setter: Sets the catalogs.
        :type: GWCatalogs
        """
        return self._catalogs

    @catalogs.setter
    def catalogs(self, value):
        self._catalogs = value

    @property
    def save_img_dir(self):
        """Save image directory for plot.

        :getter: Returns the directory where plots are saved
        :setter: Sets the directory where plots are saved.
        :type: str
        """
        return self._save_img_dir

    @save_img_dir.setter
    def save_img_dir(self, value):
        self._save_img_dir = value

    def _save_figure(self, fig, title: str) -> None:
        """Save ``fig`` as ``<title, spaces replaced by '_'>.png`` in
        ``save_img_dir``; do nothing when no directory is configured."""
        if self.save_img_dir:
            fig.savefig(
                os.path.join(
                    self.save_img_dir, title.replace(" ", "_") + ".png"
                )
            )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameter_time_evolution(
        self,
        df: pd.DataFrame,
        time_parameter: str,
        parameter: str,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Plot the parameter that evolves over time.

        Note: extra parameter can be configured:
            - plot_type, default : scatter
            - grid, default : True
            - marker, default : 's'
            - linestyle, default : '-'
            - yscale, default : log
            - title, default : Evolution

        Positional ``*args`` are accepted but ignored.

        Args:
            df (pd.DataFrame): data
            time_parameter (str): time parameter in the data (x column)
            parameter (str): parameter to plot over the time (y column)
        """
        # The documented option is "plot_type"; the key "scatter" used
        # previously made the plot type impossible to configure.
        plot_type = self._get_variable(kwargs, "plot_type", "scatter")
        grid = self._get_variable(kwargs, "grid", True)
        marker = self._get_variable(kwargs, "marker", "s")
        linestyle = self._get_variable(kwargs, "linestyle", "-")
        yscale = self._get_variable(kwargs, "yscale", "log")
        title: str = self._get_variable(kwargs, "title", "Evolution")

        fig, ax = plt.subplots(figsize=[8, 6], dpi=100)
        df.plot(
            kind=plot_type,
            x=time_parameter,
            y=parameter,
            ax=ax,
            grid=grid,
            marker=marker,
            linestyle=linestyle,
        )
        ax.set_yscale(yscale)
        ax.set_title(title)
        self._save_figure(fig, title)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameter_time_evolution_from_source(
        self,
        catalog_name: str,
        source_name: str,
        time_parameter: str,
        parameter: str,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Plot the parameter that evolves over time for a given source
        starting from a catalog.

        The source history is obtained with ``catalogs.get_lineage`` and then
        plotted by ``plot_parameter_time_evolution``. That method accepts the
        same extra options:
            - plot_type, default : scatter
            - grid, default : True
            - marker, default : 's'
            - linestyle, default : '-'
            - yscale, default : log
            - title, default : Evolution

        Args:
            catalog_name (str) : Start the evolution from the oldest one
                until that one
            source_name (str) : source name to follow up
            time_parameter (str): time parameter in the data
            parameter (str): parameter to plot over the time
        """
        source_history = self.catalogs.get_lineage(catalog_name, source_name)
        self.plot_parameter_time_evolution(
            source_history, time_parameter, parameter, *args, **kwargs
        )

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameters_evolution(
        self,
        all_epochs: pd.DataFrame,
        params: List,
        scales: List,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Show evolution over many different epochs as violin plots.

        One subplot is drawn per parameter on a two-column grid. Extra
        options: ``title`` (default "Parameter Evolution") and ``x_title``,
        the column used on the x axis (default "Observation Week").

        Args:
            all_epochs (pd.DataFrame): observation of a source at
                different epochs
            params (List): list of parameters to plot
            scales (List): y-axis scale for each plot; must have at least
                as many entries as ``params``
        """
        title = self._get_variable(kwargs, "title", "Parameter Evolution")
        x_title = self._get_variable(kwargs, "x_title", "Observation Week")
        ncols = 2
        nrows = int(np.ceil(len(params) / ncols))
        fig = plt.figure(figsize=(10.0, 10.0), dpi=100)

        for idx, param in enumerate(params):
            ax = fig.add_subplot(nrows, ncols, idx + 1)
            sns.violinplot(
                ax=ax,
                x=x_title,
                y=param,
                data=all_epochs,
                scale="width",
                width=0.8,
                inner="quartile",
            )
            ax.set_yscale(scales[idx])
            ax.grid(axis="y")

        fig.suptitle(title)
        self._save_figure(fig, title)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_parameters_correlation_evolution(
        self,
        allEpochs: pd.DataFrame,
        wks: List,
        params: List,
        colors: List,
        *args,
        **kwargs,
    ) -> NoReturn:
        """To dig into how parameter correlations might change over time, we
        can look at a time-evolving corner plot.

        The samples of each selected week are overlaid on one shared figure.
        Extra option: ``title`` (default "Evolution of parameters").

        Args:
            allEpochs (pd.DataFrame): observation of a source at different
                epochs; must contain an "Observation Week" column
            wks (List): weeks to plot
            params (List): parameters to plot
            colors (List): color according the weeks (same order as ``wks``)
        """
        title = self._get_variable(kwargs, "title", "Evolution of parameters")
        fig = plt.figure(figsize=[8, 8], dpi=100)
        for idx, wk in enumerate(wks):
            epoch = allEpochs[allEpochs["Observation Week"] == wk]
            # Overlay this week's samples onto the shared figure.
            self.plot_corners_ds(epoch[params], fig=fig, color=colors[idx])
        fig.suptitle(title)
        self._save_figure(fig, title)

    @UtilsMonitoring.io(entry=True, exit=False, level=logging.DEBUG)
    def plot_skymap_evolution(
        self,
        nside: int,
        allEpochs: pd.DataFrame,
        wks: List,
        system: FrameEnum = FrameEnum.GALACTIC,
        *args,
        **kwargs,
    ) -> NoReturn:
        """Plot the skymap evolution, one Mollweide map per selected week.

        Extra option: ``title`` (default "Sky Localization Evolution").

        Args:
            nside (int): parameter for healpix related to the number of cells
            allEpochs (pd.DataFrame): observation of a source at different
                epochs; must contain an "Observation Week" column
            wks (List): weeks to plot
            system (FrameEnum, optional): coordinate reference frame. Defaults
                to 'FrameEnum.GALACTIC'.
        """
        title = self._get_variable(
            kwargs, "title", "Sky Localization Evolution"
        )
        fig = plt.figure(figsize=(10, 10), dpi=100)
        ncols = 2
        nrows = int(np.ceil(len(wks) / ncols))
        for idx, wk in enumerate(wks):
            hpmap = HPhist(
                allEpochs[allEpochs["Observation Week"] == wk], nside, system
            )
            ax = fig.add_subplot(
                nrows, ncols, idx + 1, projection="geo degrees mollweide"
            )
            ax.grid()
            ax.imshow_hpx(hpmap, cmap="plasma")
            ax.set_title(f"Week {wk}")
        fig.suptitle(title)
        self._save_figure(fig, title)