Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions agasc/scripts/obs_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
#!/usr/bin/env python

"""
Observation Statistics.


"""

import argparse
import logging
import os
import time
from multiprocessing import Pool
from pathlib import Path
from pprint import pformat

import numpy as np
import yaml
from astropy.table import Table, vstack
from cxotime import CxoTime
from cxotime import units as u
from ska_helpers import logging as ska_logging
from tqdm import tqdm

from agasc import agasc
from agasc.supplement.magnitudes import mag_estimate, star_obs_catalogs


def get_parser():
    """
    Build the command-line argument parser for the observation-statistics script.

    :return: argparse.ArgumentParser
    """
    # (flag, add_argument keyword options) — one entry per CLI option
    option_specs = [
        (
            "--start",
            {
                "help": (
                    "Include only stars observed after this time."
                    " CxoTime-compatible time stamp."
                    " Default: now - 30 days."
                ),
            },
        ),
        (
            "--stop",
            {
                "help": (
                    "Include only stars observed before this time."
                    " CxoTime-compatible time stamp."
                    " Default: now."
                ),
            },
        ),
        (
            "--output-dir",
            {
                "help": "Directory where to write the result. Default: .",
                "default": ".",
            },
        ),
        (
            "--log-level",
            {
                "default": "info",
                "choices": ["debug", "info", "warning", "error"],
            },
        ),
        (
            "--multiprocessing",
            {
                "help": "Use multiprocessing to speed up the processing of observations.",
                "action": "store_true",
                "default": False,
            },
        ),
        (
            "--save-call-args",
            {
                "help": (
                    "Save the input arguments to a YAML file in the output directory."
                    " The file name is call_args.yml or call_args.N.yml if the former already exists."
                ),
                "action": "store_true",
                "default": False,
            },
        ),
    ]

    parser = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def get_args():
    """
    Parse and normalize the command-line arguments for main().

    Side effects: configures the "agasc.supplement" logger, creates the output
    directory if it does not exist, and (with --save-call-args) writes the call
    arguments to a YAML file in the output directory.

    :return: dict
        Keys: "output_dir" (Path), "start" (str, CxoTime date),
        "stop" (str, CxoTime date), "multiprocessing" (bool).
    """
    logger = ska_logging.basic_logger(
        name="agasc.supplement",
        level="WARNING",
        format="%(asctime)s %(message)s",
    )

    the_parser = get_parser()
    args = the_parser.parse_args()
    logger.setLevel(args.log_level.upper())

    args.output_dir = Path(os.path.expandvars(args.output_dir))

    # set start/stop times (default: the last 30 days, ending now)
    args.stop = CxoTime(args.stop).date if args.stop else CxoTime.now().date
    args.start = (
        CxoTime(args.start).date if args.start else (CxoTime.now() - 30 * u.day).date
    )

    if not args.output_dir.exists():
        args.output_dir.mkdir(parents=True)

    # Paths are not YAML-serializable, so stringify them. Computed
    # unconditionally so the "Input arguments" logging below always works
    # (previously yaml_args was only defined when --save-call-args was given).
    yaml_args = {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()}

    # save call args just in case
    if args.save_call_args:
        args_log_file = get_next_file_name(args.output_dir / "call_args.yml")
        logger.info(f"Writing input arguments to {args_log_file}")
        with open(args_log_file, "w") as fh:
            yaml.dump(yaml_args, fh)

    logger.info("Input arguments")
    for line in pformat(yaml_args).split("\n"):
        logger.info(line.rstrip())

    return {
        "output_dir": args.output_dir,
        "start": args.start,
        "stop": args.stop,
        "multiprocessing": args.multiprocessing,
    }


def get_next_file_name(file_name):
    """
    Return *file_name* if it does not exist, otherwise the first available
    numbered variant (``name.1.ext``, ``name.2.ext``, ...).

    :param file_name: pathlib.Path
    :return: pathlib.Path
    """
    if not file_name.exists():
        return file_name
    counter = 1
    # probe name.1.ext, name.2.ext, ... until a free slot is found
    while (candidate := file_name.with_suffix(f".{counter}{file_name.suffix}")).exists():
        counter += 1
    return candidate


def get_multi_obs_stats(star_obs, obs_status_override):
    """
    Fetch telemetry for the given observations, trim each observation to its
    first 2000 seconds, and compute per-observation statistics.

    :param star_obs: Table
        Table of star observations to process.
    :param obs_status_override: dict
        Dictionary overriding the OK flag for specific observations
        (passed through to mag_estimate.get_multi_obs_stats).
    :return: astropy.table.Table, list
        obs_stats, failures
    """
    telem = mag_estimate.get_telemetry_by_observations(
        star_obs, ignore_exceptions=True, as_table=False
    )

    # Only keep telemetry from the first 2000 seconds of each observation.
    # Entries without a (non-empty) "times" array are left untouched
    # (e.g. error entries produced when ignore_exceptions is True).
    for idx, tel in enumerate(telem):
        if "times" not in tel or len(tel["times"]) == 0:
            continue
        keep = (tel["times"] - tel["times"][0]) < 2000
        telem[idx] = {key: value[keep] for key, value in tel.items()}

    obs_stats, failures = mag_estimate.get_multi_obs_stats(
        star_obs, telem=telem, obs_status_override=obs_status_override
    )
    return obs_stats, failures


def get_multi_obs_stats_pool(
    star_obs, obs_status_override, batch_size=20, no_progress=None
):
    """
    Call get_multi_obs_stats in batches using a multiprocessing.Pool.

    :param star_obs: Table
        Table of star observations to process.
    :param obs_status_override: dict.
        Dictionary overriding the OK flag for specific observations.
        Keys are (OBSID, AGASC ID) pairs, values are dictionaries like
        {'obs_ok': True, 'comments': 'some comment'}
    :param batch_size: int
        Number of observations handed to each pool job.
    :param no_progress: bool, optional
        If True, disable the tqdm progress bar.
    :return: astropy.table.Table, list
        obs_stats, fails
    """
    logger = logging.getLogger("agasc.supplement")

    jobs = []
    args = []
    finished = 0
    logger.info(f"Processing {batch_size} observations per job")
    # split the observations into batches of at most batch_size rows each
    for i in range(0, len(star_obs), batch_size):
        args.append(star_obs[i : i + batch_size])

    with Pool() as pool:
        for arg in args:
            jobs.append(
                pool.apply_async(get_multi_obs_stats, [arg, obs_status_override])
            )
        # poll once per second, advancing the progress bar by the number of
        # newly-completed jobs
        bar = tqdm(total=len(jobs), desc="progress", disable=no_progress, unit="job")
        while finished < len(jobs):
            finished = sum([f.ready() for f in jobs])
            if finished - bar.n:
                bar.update(finished - bar.n)
            time.sleep(1)
        bar.close()

    # a failed job is reported once per observation in its batch, since we
    # cannot tell which observation caused the failure
    fails = []
    for arg, job in zip(args, jobs):
        if job.successful():
            continue
        try:
            # get() re-raises the job's exception so its message can be recorded
            job.get()
        except Exception as e:
            for obs in arg:
                fails.append(
                    dict(
                        mag_estimate.MagStatsException(
                            agasc_id=obs["agasc_id"],
                            obsid=obs["obsid"],
                            msg=f"Failed job: {e}",
                        )
                    )
                )

    results = [job.get() for job in jobs if job.successful()]

    obs_stats = [r[0] for r in results if r[0] is not None]
    obs_stats = vstack(obs_stats) if obs_stats else Table()
    # each successful result also carries its own list of per-observation failures
    fails += sum([r[1] for r in results], [])

    return obs_stats, fails


def main():
    """
    Compute observation statistics for stars observed in a time range.

    Loads the star-observation catalogs, selects observations between the
    --start/--stop times, computes per-observation magnitude statistics
    (optionally in parallel), and writes obs_stats.fits plus an
    obs_stats_failures.yml file (if there were failures) to the output
    directory.
    """
    args = get_args()

    star_obs_catalogs.load(args["stop"])

    # status/comment overrides keyed on (mp_starcat_time, agasc_id), taken from
    # the AGASC supplement "obs" table
    obs_status_override = {
        (r["mp_starcat_time"], r["agasc_id"]): {
            "status": r["status"],
            "comments": r["comments"],
        }
        for r in agasc.get_supplement_table("obs")
    }

    # keep only observations whose catalog time falls within [start, stop]
    obs_in_time = (star_obs_catalogs.STARS_OBS["mp_starcat_time"] >= args["start"]) & (
        star_obs_catalogs.STARS_OBS["mp_starcat_time"] <= args["stop"]
    )
    star_obs = star_obs_catalogs.STARS_OBS[obs_in_time]

    if args["multiprocessing"]:
        obs_stats, failures = get_multi_obs_stats_pool(star_obs, obs_status_override)
    else:
        obs_stats, failures = get_multi_obs_stats(star_obs, obs_status_override)

    if len(obs_stats) > 0:
        print(f"Successfully processed {len(obs_stats)} observations.")
        obs_stats.sort(["mp_starcat_time", "slot"])
        obs_stats.write(args["output_dir"] / "obs_stats.fits", overwrite=True)
    else:
        print("No observations processed successfully.")

    if failures:
        # Clean up for YAML dumping: convert numpy scalar types to native
        # Python types. np.generic always provides .item(), so no extra
        # hasattr check is needed.
        failures = [
            {k: (v.item() if isinstance(v, np.generic) else v) for k, v in d.items()}
            for d in failures
        ]
        with open(args["output_dir"] / "obs_stats_failures.yml", "w") as fh:
            yaml.dump(failures, fh)
34 changes: 23 additions & 11 deletions agasc/supplement/magnitudes/mag_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,13 @@ def get_telemetry_by_observations(observations, ignore_exceptions=False, as_tabl
raise

if telem:
return vstack([Table(tel) for tel in telem]) if as_table else telem
# we ignore entries with error_code here because this means ignore_exceptions is true
# (otherwise the exception would have been raised already)
return (
vstack([Table(tel) for tel in telem if "error_code" not in tel])
if as_table
else telem
)

return []

Expand Down Expand Up @@ -584,7 +590,8 @@ def add_obs_info(telem, obs_stats):
o = telem["obsid"] == obsid
telem["obs_ok"][o] = np.ones(np.count_nonzero(o), dtype=bool) * s["obs_ok"]
if (
np.any(telem["mag_est_ok"][o])
len(telem[o]) > 0
and np.any(telem["mag_est_ok"][o])
and s["f_mag_est_ok"] > 0
and np.isfinite(s["q75"])
and np.isfinite(s["q25"])
Expand Down Expand Up @@ -1238,17 +1245,22 @@ def get_agasc_id_stats(
star_obs, obs_status_override=obs_status_override, telem=all_telem
)

# combine magnitude estimates using a weighted mean
weighted_mean = get_weighted_mean(stats)
stats["w"] = weighted_mean["weights"]
stats["mean_corrected"] = weighted_mean["mean_corrected"]
stats["weighted_mean"] = weighted_mean["weighted_mean"]
# we do this in this method because a weighted mean doesn't make much sense for a list
# of observations. It only makes sense when considering all observations of a star.
if len(stats) > 0:
# combine magnitude estimates using a weighted mean
weighted_mean = get_weighted_mean(stats)
stats["w"] = weighted_mean["weights"]
stats["mean_corrected"] = weighted_mean["mean_corrected"]
stats["weighted_mean"] = weighted_mean["weighted_mean"]
else:
# or make sure column exists
stats["w"] = np.array([])
stats["mean_corrected"] = np.array([])
stats["weighted_mean"] = np.array([])

star = get_star(agasc_id, use_supplement=False)

# still need to check that this is the same as before
last_obs_time = CxoTime(stats["mp_starcat_time"][-1]).cxcsec

logger.debug(" identifying outlying observations...")
for s, t in zip(stats, all_telem, strict=True):
if s["no_telem"] or s["excluded"]:
Expand Down Expand Up @@ -1333,7 +1345,7 @@ def get_agasc_id_stats(
result.update(
{
"color": star["COLOR1"],
"last_obs_time": last_obs_time,
"last_obs_time": CxoTime(stats["mp_starcat_time"][-1]).cxcsec,
"mag_aca": star["MAG_ACA"],
"mag_aca_err": star["MAG_ACA_ERR"] / 100,
"mag_obs_err": min_mag_obs_err,
Expand Down
Loading