by Josh Dillon, last updated June 19, 2023
This notebook is designed to figure out a single full-day RFI mask using the best autocorrelations, taking individual file_calibration notebook results as a prior but then potentially undoing flags.
Here's a set of links to skip to particular figures and tables:
import time
tstart = time.time()
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin # REQUIRED to have the compression plugins available
import numpy as np
import pandas as pd
import glob
import os
import matplotlib.pyplot as plt
import matplotlib
import copy
import warnings
from pyuvdata import UVFlag, UVData, UVCal
from hera_cal import io, utils, abscal
from hera_cal.smooth_cal import CalibrationSmoother, dpss_filters, solve_2D_DPSS
from hera_qm import ant_class, xrfi, metrics_io
from hera_filters import dspec
from IPython.display import display, HTML
%matplotlib inline
display(HTML("<style>.container { width:100% !important; }</style>"))
_ = np.seterr(all='ignore') # get rid of red warnings
%config InlineBackend.figure_format = 'retina'
# File naming: every setting can be overridden through an environment variable.
SUM_FILE = os.getenv("SUM_FILE")
# SUM_FILE = '/users/jsdillon/lustre/H6C/abscal/2459853/zen.2459853.25518.sum.uvh5' # If sum_file is not defined in the environment variables, define it here.
SUM_SUFFIX = os.getenv("SUM_SUFFIX", 'sum.uvh5')
SUM_AUTOS_SUFFIX = os.getenv("SUM_AUTOS_SUFFIX", 'sum.autos.uvh5')
DIFF_AUTOS_SUFFIX = os.getenv("DIFF_AUTOS_SUFFIX", 'diff.autos.uvh5')
CAL_SUFFIX = os.getenv("CAL_SUFFIX", 'sum.omni.calfits')
ANT_CLASS_SUFFIX = os.getenv("ANT_CLASS_SUFFIX", 'sum.ant_class.csv')
APRIORI_YAML_PATH = os.getenv("APRIORI_YAML_PATH")
OUT_FLAG_SUFFIX = os.getenv("OUT_FLAG_SUFFIX", 'sum.flag_waterfall.h5')

# Drop the last three dot-separated fields of SUM_FILE and put a wildcard in
# their place, so that the pattern matches every sum file sharing that prefix.
sum_glob = f"{'.'.join(SUM_FILE.split('.')[:-3])}.*.{SUM_SUFFIX}"

# The glob patterns for the other products are obtained by swapping the suffix.
auto_sums_glob = sum_glob.replace(SUM_SUFFIX, SUM_AUTOS_SUFFIX)
auto_diffs_glob = sum_glob.replace(SUM_SUFFIX, DIFF_AUTOS_SUFFIX)
cal_files_glob = sum_glob.replace(SUM_SUFFIX, CAL_SUFFIX)
ant_class_csvs_glob = sum_glob.replace(SUM_SUFFIX, ANT_CLASS_SUFFIX)
def env_float(name, default, *legacy_names):
    '''Read a float setting from the environment.

    Arguments:
        name: environment variable to look up first.
        default: value used when neither name nor any of legacy_names is set.
        legacy_names: additional variable names tried in order after name. Used to keep
            honoring historically misspelled variable names.

    Returns:
        The first value found (or default), converted with float().
    '''
    for key in (name, *legacy_names):
        if key in os.environ:
            return float(os.environ[key])
    return float(default)


# A priori flag settings
FM_LOW_FREQ = env_float("FM_LOW_FREQ", 87.5) # in MHz
FM_HIGH_FREQ = env_float("FM_HIGH_FREQ", 108.0) # in MHz
FM_freq_range = [FM_LOW_FREQ * 1e6, FM_HIGH_FREQ * 1e6]
MAX_SOLAR_ALT = env_float("MAX_SOLAR_ALT", 0.0) # in degrees

# DPSS settings
FREQ_FILTER_SCALE = env_float("FREQ_FILTER_SCALE", 5.0) # in MHz
TIME_FILTER_SCALE = env_float("TIME_FILTER_SCALE", 450.0) # in s
EIGENVAL_CUTOFF = env_float("EIGENVAL_CUTOFF", 1e-12)

# Outlier flagging settings
MIN_FRAC_OF_AUTOS = env_float("MIN_FRAC_OF_AUTOS", .25)
# The correctly spelled variable now takes effect; the old misspelling "MAX_AUTRO_L2"
# is still honored as a fallback so existing pipeline configurations keep working.
MAX_AUTO_L2 = env_float("MAX_AUTO_L2", 1.2, "MAX_AUTRO_L2")
Z_THRESH = env_float("Z_THRESH", 5.0)
WS_Z_THRESH = env_float("WS_Z_THRESH", 4.0)
AVG_Z_THRESH = env_float("AVG_Z_THRESH", 1.5)
# Same as above: "REPEAT_FLAG_Z_THESH" is the historical misspelling.
REPEAT_FLAG_Z_THRESH = env_float("REPEAT_FLAG_Z_THRESH", 2.0, "REPEAT_FLAG_Z_THESH")
MAX_FREQ_FLAG_FRAC = env_float("MAX_FREQ_FLAG_FRAC", .25)
MAX_TIME_FLAG_FRAC = env_float("MAX_TIME_FLAG_FRAC", .1)

# Report the settings actually in use (plain name lookup, no eval needed)
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'MAX_SOLAR_ALT', 'FREQ_FILTER_SCALE', 'TIME_FILTER_SCALE',
                'EIGENVAL_CUTOFF', 'MIN_FRAC_OF_AUTOS', 'MAX_AUTO_L2', 'Z_THRESH', 'WS_Z_THRESH', 'AVG_Z_THRESH',
                'REPEAT_FLAG_Z_THRESH', 'MAX_FREQ_FLAG_FRAC', 'MAX_TIME_FLAG_FRAC']:
    print(f'{setting} = {globals()[setting]}')
FM_LOW_FREQ = 87.5 FM_HIGH_FREQ = 108.0 MAX_SOLAR_ALT = 0.0 FREQ_FILTER_SCALE = 5.0 TIME_FILTER_SCALE = 450.0 EIGENVAL_CUTOFF = 1e-12 MIN_FRAC_OF_AUTOS = 0.25 MAX_AUTO_L2 = 1.2 Z_THRESH = 5.0 WS_Z_THRESH = 4.0 AVG_Z_THRESH = 1.5 REPEAT_FLAG_Z_THRESH = 2.0 MAX_FREQ_FLAG_FRAC = 0.25 MAX_TIME_FLAG_FRAC = 0.1
def _sorted_matches(pattern, suffix):
    '''Return the sorted list of files matching pattern, reporting how many were found and the first one.'''
    matches = sorted(glob.glob(pattern))
    print(f'Found {len(matches)} *.{suffix} files starting with {matches[0]}.')
    return matches


# Locate every per-file product needed for the full-day analysis
auto_sums = _sorted_matches(auto_sums_glob, SUM_AUTOS_SUFFIX)
auto_diffs = _sorted_matches(auto_diffs_glob, DIFF_AUTOS_SUFFIX)
cal_files = _sorted_matches(cal_files_glob, CAL_SUFFIX)
ant_class_csvs = _sorted_matches(ant_class_csvs_glob, ANT_CLASS_SUFFIX)
Found 361 *.sum.autos.uvh5 files starting with /mnt/sn1/2460152/zen.2460152.42109.sum.autos.uvh5. Found 361 *.diff.autos.uvh5 files starting with /mnt/sn1/2460152/zen.2460152.42109.diff.autos.uvh5. Found 361 *.sum.omni.calfits files starting with /mnt/sn1/2460152/zen.2460152.42109.sum.omni.calfits. Found 361 *.sum.ant_class.csv files starting with /mnt/sn1/2460152/zen.2460152.42109.sum.ant_class.csv.
# Load ant_class csvs, discarding rows that are entirely empty
tables = [pd.read_csv(csv_path).dropna(axis=0, how='all') for csv_path in ant_class_csvs]
first_table = tables[0]
# After the first column, columns alternate between a value and its classification
table_cols = first_table.columns[1::2]
class_cols = first_table.columns[2::2]

# Figure out antennas that were not flagged when the sun was down, or were only flagged for Even/Odd Zeros or Redcal chi^2 or Bad X-Engine Diffs
ap_strs = np.array(first_table['Antenna'])
# Indexed below as [file, antenna-pol row, classification column]
ant_flags = np.array([t[class_cols] for t in tables]) == 'bad'
sun_low_enough = np.array([t['Solar Alt'] < MAX_SOLAR_ALT for t in tables])
# Antenna numbers: each entry of ap_strs is the number followed by a one-character polarization
ants = sorted({int(a[:-1]) for a in ap_strs})

# Classifications that do not disqualify an antenna from being a candidate
ignored_classes = ['Antenna Class', 'Even/Odd Zeros Class', 'Redcal chi^2 Class', 'Bad Diff X-Engines Class']
checked_cols = [cc for cc, colname in enumerate(class_cols) if colname not in ignored_classes]

candidate_autos = set()
for i, ap_str in enumerate(ap_strs):
    # Only 'bad' classifications that coincide with a low-enough sun count against the antenna
    has_other_flags = np.any(ant_flags[:, i, checked_cols] & sun_low_enough[:, i, np.newaxis])
    if has_other_flags:
        continue
    ap = int(ap_str[:-1]), utils.comply_pol(ap_str[-1])
    candidate_autos.add(utils.join_bl(ap, ap))
# Load sum and diff autos, checking to see whether any of them show packet loss
good_data, info_dicts = {}, {}
for sf, df in zip(auto_sums, auto_diffs):
    # Only the autos that are still candidates are read from each file
    rv = io.read_hera_hdf5(sf, bls=candidate_autos)
    good_data[sf], info_dicts[sf] = rv['data'], rv['info']
    diff = io.read_hera_hdf5(df, bls=candidate_autos)['data']

    # Antennas classified as bad by the even/odd zeros check stop being candidates,
    # which also means they are no longer read from the remaining files
    zeros_class = ant_class.even_odd_zeros_checker(good_data[sf], diff)
    for ant in zeros_class.bad_ants:
        candidate_autos.remove(utils.join_bl(ant, ant))
# Load the calibration solutions for the whole day (cspa and chisq are not loaded, no refant is picked)
cs = CalibrationSmoother(cal_files, load_cspa=False, load_chisq=False, pick_refant=False)

# Build a per-integration boolean mask of a priori flagged times, when a YAML file is provided
if APRIORI_YAML_PATH is not None:
    print(f'Loading a priori flagged times from {APRIORI_YAML_PATH}')
    apriori_flags = np.zeros(len(cs.time_grid), dtype=bool)
    flagged_ints = metrics_io.read_a_priori_int_flags(APRIORI_YAML_PATH, times=cs.time_grid).astype(int)
    apriori_flags[flagged_ints] = True
Loading a priori flagged times from /mnt/sn1/2460152/2460152_apriori_flags.yaml
# Start from flags that are set in every one of the calibration flag grids
initial_cal_flags = np.all(list(cs.flag_grids.values()), axis=0)
def average_autos(per_file_autos, bls_to_use, auto_sums, cs):
    '''Average autocorrelations over baselines and place them on the time grid of cs.

    Arguments:
        per_file_autos: dictionary mapping each file in auto_sums to a dictionary of
            baseline -> waterfall for that file.
        bls_to_use: iterable of baseline keys to average together.
        auto_sums: list of file keys, paired in order with cs.cals.
        cs: CalibrationSmoother-like object providing time_grid, freqs, cals, and time_indices.

    Returns:
        Float array of shape (len(cs.time_grid), len(cs.freqs)) holding the absolute value
        of the baseline-averaged autos. Rows not covered by any file are left at zero.
    '''
    # First average over baselines within every file...
    per_file_means = {}
    for sf in auto_sums:
        per_file_means[sf] = np.mean([per_file_autos[sf][bl] for bl in bls_to_use], axis=0)

    # ...then drop each file's result into the rows of the full grid that it covers
    n_times, n_freqs = len(cs.time_grid), len(cs.freqs)
    avg_autos = np.zeros((n_times, n_freqs), dtype=float)
    for sf, cf in zip(auto_sums, cs.cals):
        avg_autos[cs.time_indices[cf], :] = np.abs(per_file_means[sf])
    return avg_autos
# Average all candidate autocorrelations together onto the time grid of the calibration solutions
avg_candidate_auto = average_autos(good_data, candidate_autos, auto_sums, cs)
def flag_FM(flags, freqs, freq_range=(87.5e6, 108e6)):
    '''Apply flags to all frequencies within freq_range (in Hz).

    Arguments:
        flags: 2D boolean array indexed as [time, frequency]. Modified in place.
        freqs: sequence or array of frequencies, one per column of flags, in the same
            units as freq_range.
        freq_range: (low, high) pair. Both edges are inclusive.

    Returns:
        None. The flags array is updated in place.
    '''
    # np.asarray lets plain lists/tuples of frequencies work as well as arrays
    freqs = np.asarray(freqs)
    low, high = freq_range
    flags[:, (freqs >= low) & (freqs <= high)] = True
# Flag the configured FM band in the initial flags (the array is modified in place)
flag_FM(initial_cal_flags, cs.freqs, freq_range=FM_freq_range)
def flag_sun(flags, times, max_solar_alt=0):
    '''Flag, in place, every time at which the solar altitude is at or above max_solar_alt (in degrees).

    Arguments:
        flags: boolean array whose first axis runs over times. Modified in place.
        times: times passed to utils.get_sun_alt to compute the solar altitude.
        max_solar_alt: altitude threshold. Times exactly at the threshold are flagged too.
    '''
    sun_too_high = utils.get_sun_alt(times) >= max_solar_alt
    flags[sun_too_high, :] = True
# Flag all times at which the sun is too high
flag_sun(initial_cal_flags, cs.time_grid, max_solar_alt=MAX_SOLAR_ALT)

# Flag every channel of the integrations listed as a priori bad, if any were loaded
if APRIORI_YAML_PATH is not None:
    initial_cal_flags[apriori_flags] = True
def predict_auto_noise(auto, dt, df, nsamples=1):
    '''Predict the noise on an (antenna-averaged) autocorrelation.

    Arguments:
        auto: autocorrelation value(s). Only the absolute value is used.
        dt: integration time. The product dt * df must be unitless.
        df: channel width. The product dt * df must be unitless.
        nsamples: number of autocorrelations averaged together to form auto.

    Returns:
        |auto| divided by sqrt(int(dt * df) * nsamples / 2).
    '''
    # dt * df is truncated to an integer before being scaled by nsamples
    int_count = nsamples * int(dt * df)
    scale = np.sqrt(int_count / 2)
    return np.abs(auto) / scale
# Figure out noise and weights
# The time grid is in JD (days), so the median spacing is converted to seconds
int_time = 24 * 3600 * np.median(np.diff(cs.time_grid))
# Median channel spacing, in the units of cs.freqs (Hz)
chan_res = np.median(np.diff(cs.freqs))
# NOTE(review): nsamples=1 is used although avg_candidate_auto averages several autos — confirm this is intended
noise = predict_auto_noise(avg_candidate_auto, int_time, chan_res, nsamples=1)
# Inverse-variance weights, set to zero wherever the initial flags are set
wgts = np.where(initial_cal_flags, 0, noise**-2)
# Get slices to index into the region of the waterfall outside of which it's 100% flagged.
# np.flatnonzero always returns a 1D index array, so this also works when exactly one
# integration (or channel) is not fully flagged; squeezing np.argwhere's output would
# collapse that case to a 0-d array that cannot be indexed with [0].
unflagged_ints = np.flatnonzero(~np.all(initial_cal_flags, axis=1))
ints_to_filt = slice(unflagged_ints[0], unflagged_ints[-1] + 1)
unflagged_chans = np.flatnonzero(~np.all(initial_cal_flags, axis=0))
chans_to_filt = slice(unflagged_chans[0], unflagged_chans[-1] + 1)
# Filter every autocorrelation individually
cached_output = {}
models = {}
sqrt_mean_sqs = {}

# The DPSS filters are built once, for the sub-region of the waterfall that is not fully flagged
time_filters, freq_filters = dpss_filters(
    freqs=cs.freqs[chans_to_filt],  # Hz
    times=cs.time_grid[ints_to_filt],  # JD
    freq_scale=FREQ_FILTER_SCALE,
    time_scale=TIME_FILTER_SCALE,
    eigenval_cutoff=EIGENVAL_CUTOFF,
)
filt_region = (ints_to_filt, chans_to_filt)

for bl in candidate_autos:
    auto_here = average_autos(good_data, [bl], auto_sums, cs)

    # Start from a copy of the data, so that outside the filtered region the model equals the data
    models[bl] = np.array(auto_here)
    # The cache returned by each solve is handed to the next one (presumably to reuse shared work)
    model, cached_output = solve_2D_DPSS(auto_here[filt_region], wgts[filt_region],
                                         time_filters, freq_filters,
                                         method='lu_solve', cached_input=cached_output)
    models[bl][filt_region] = model

    # Root-mean-square of the z-score of the residual over all initially unflagged pixels
    noise_model = predict_auto_noise(models[bl], int_time, chan_res, nsamples=1)
    zscore = np.where(initial_cal_flags, np.nan, (auto_here - models[bl]) / noise_model)
    sqrt_mean_sqs[bl] = np.nanmean(zscore**2)**.5
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# Pick best autocorrelations to filter on: keep everything at or below MAX_AUTO_L2, but
# raise the bound if needed so that at least the MIN_FRAC_OF_AUTOS quantile is kept
quantile_bound = np.quantile(list(sqrt_mean_sqs.values()), MIN_FRAC_OF_AUTOS)
L2_bound = max(quantile_bound, MAX_AUTO_L2)
good_auto_bls = [
    bl for bl in candidate_autos
    if sqrt_mean_sqs[bl] <= L2_bound
]
print(f'Using {len(good_auto_bls)} out of {len(candidate_autos)} candidate autocorrelations ({len(good_auto_bls) / len(candidate_autos):.2%}).')
Using 78 out of 122 candidate autocorrelations (63.93%).
# Waterfall plot extent: frequency in MHz on x; on y, time relative to the integer part of the first JD, with the first time at the top
extent = [cs.freqs[0]/1e6, cs.freqs[-1]/1e6, cs.time_grid[-1] - int(cs.time_grid[0]), cs.time_grid[0] - int(cs.time_grid[0])]
def plot_all_filtered_bls(N_per_row=8):
    '''Plot the z-score waterfall of every filtered autocorrelation, most stable first.

    Each panel shows (auto - DPSS model) / predicted noise with initially flagged pixels
    blanked. Each title gives the antenna, polarization, and root-mean-square z-score,
    drawn in red when it exceeds L2_bound.

    Arguments:
        N_per_row: number of panels per row of the figure.
    '''
    N_rows = int(np.ceil(len(candidate_autos) / N_per_row))
    # squeeze=False keeps `axes` two-dimensional even when there is a single row or a
    # single column, so that axes[-1, :] below is always valid
    fig, axes = plt.subplots(N_rows, N_per_row, figsize=(14, 3 * N_rows), dpi=100,
                             sharex=True, sharey=True, squeeze=False,
                             gridspec_kw={'wspace': 0, 'hspace': .18})

    bls_by_stability = sorted(sqrt_mean_sqs, key=sqrt_mean_sqs.get)
    for i, (ax, bl) in enumerate(zip(axes.flatten(), bls_by_stability)):
        auto_here = average_autos(good_data, [bl], auto_sums, cs)
        noise_model = predict_auto_noise(models[bl], int_time, chan_res, nsamples=1)

        im = ax.imshow(np.where(initial_cal_flags, np.nan, (auto_here - models[bl]) / noise_model).real,
                       aspect='auto', interpolation='none', cmap='bwr', vmin=-10, vmax=10, extent=extent)
        ax.set_title(f'{bl[0]}{bl[2][0]}: {sqrt_mean_sqs[bl]:.3}', color=('k' if sqrt_mean_sqs[bl] <= L2_bound else 'r'), fontsize=10)

        if i == 0:
            # A single colorbar shared by all panels, placed above the grid
            plt.colorbar(im, ax=axes, location='top', label=r'Autocorrelation z-score after DPSS filtering (with $\langle z^2 \rangle^{1/2}$)', extend='both', aspect=40, pad=.015)
        if i % N_per_row == 0:
            # Only the left-most panel of each row gets a y label
            ax.set_ylabel(f'JD - {int(cs.time_grid[0])}')

    for ax in axes[-1, :]:
        ax.set_xlabel('Frequency (MHz)')
    plt.tight_layout()
This figure shows the z-score waterfall of each candidate antenna. Also shown is the square root of the mean of the square of each waterfall, as a metric of its instability. Antennas in red are excluded from the average of the most stable antennas that is used for subsequent flagging.
# Draw the per-antenna z-score waterfalls using the default number of panels per row
plot_all_filtered_bls()
This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.