Full Day RFI Flagging¶
by Josh Dillon, last updated June 19, 2023
This notebook is designed to figure out a single full-day RFI mask using the best autocorrelations, taking individual file_calibration notebook results as a prior but then potentially undoing flags.
Here's a set of links to skip to particular figures and tables:
• Figure 1: Show All DPSS Residual z-Scores¶
• Figure 2: z-Score of DPSS-Filtered, Averaged Good Autocorrelation and Initial Flags¶
• Figure 3: z-Score of DPSS-Filtered, Averaged Good Autocorrelation and Expanded Flags¶
• Figure 4: z-Score of DPSS-Filtered, Averaged Good Autocorrelation and Final, Re-Computed Flags¶
• Figure 5: Summary of Flags Before and After Recomputing Them¶
In [1]:
import time
# Record the wall-clock time at which the notebook started running.
tstart = time.time()
In [2]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin # REQUIRED to have the compression plugins available
import numpy as np
import pandas as pd
import glob
import os
import matplotlib.pyplot as plt
import matplotlib
import copy
import warnings
import textwrap
from pyuvdata import UVFlag, UVData, UVCal
from hera_cal import io, utils, abscal
from hera_cal.smooth_cal import CalibrationSmoother, dpss_filters, solve_2D_DPSS
from hera_qm import ant_class, xrfi, metrics_io
from hera_filters import dspec
from IPython.display import display, HTML
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Inject CSS so that the notebook container uses the full width of the page.
display(HTML("<style>.container { width:100% !important; }</style>"))
_ = np.seterr(all='ignore') # ignore all numpy floating-point errors to get rid of red warnings
# Use the 'retina' format for inline figures.
%config InlineBackend.figure_format = 'retina'
Parse inputs¶
In [3]:
# get filenames
SUM_FILE = os.environ.get("SUM_FILE", None)
# SUM_FILE = '/lustre/aoc/projects/hera/h6c-analysis/IDR2/2459866/zen.2459866.25282.sum.uvh5'  # If sum_file is not defined in the environment variables, define it here.
if SUM_FILE is None:
    # Fail early with an actionable message instead of an AttributeError on None.split below.
    raise ValueError('SUM_FILE is not set. Define it as an environment variable or hard-code it in this cell.')

# suffixes identifying each of the per-file data products (overridable through the environment)
SUM_SUFFIX = os.environ.get("SUM_SUFFIX", 'sum.uvh5')
SUM_AUTOS_SUFFIX = os.environ.get("SUM_AUTOS_SUFFIX", 'sum.autos.uvh5')
DIFF_AUTOS_SUFFIX = os.environ.get("DIFF_AUTOS_SUFFIX", 'diff.autos.uvh5')
CAL_SUFFIX = os.environ.get("CAL_SUFFIX", 'sum.omni.calfits')
ANT_CLASS_SUFFIX = os.environ.get("ANT_CLASS_SUFFIX", 'sum.ant_class.csv')
APRIORI_YAML_PATH = os.environ.get("APRIORI_YAML_PATH", None)
OUT_FLAG_SUFFIX = os.environ.get("OUT_FLAG_SUFFIX", 'sum.flag_waterfall.h5')

# Drop the last three dot-separated fields of SUM_FILE and replace them with a wildcard, so that
# the globs match every file sharing SUM_FILE's prefix. Each glob is built directly from that prefix
# (rather than by string replacement on sum_glob) so that only the suffix ever differs between them.
glob_prefix = '.'.join(SUM_FILE.split('.')[:-3]) + '.*.'
sum_glob = glob_prefix + SUM_SUFFIX
auto_sums_glob = glob_prefix + SUM_AUTOS_SUFFIX
auto_diffs_glob = glob_prefix + DIFF_AUTOS_SUFFIX
cal_files_glob = glob_prefix + CAL_SUFFIX
ant_class_csvs_glob = glob_prefix + ANT_CLASS_SUFFIX
In [4]:
# A priori flag settings
FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5))  # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0))  # in MHz
FM_freq_range = [FM_LOW_FREQ * 1e6, FM_HIGH_FREQ * 1e6]  # in Hz
MAX_SOLAR_ALT = float(os.environ.get("MAX_SOLAR_ALT", 0.0))  # in degrees

# DPSS settings
FREQ_FILTER_SCALE = float(os.environ.get("FREQ_FILTER_SCALE", 5.0))  # in MHz
TIME_FILTER_SCALE = float(os.environ.get("TIME_FILTER_SCALE", 450.0))  # in s
EIGENVAL_CUTOFF = float(os.environ.get("EIGENVAL_CUTOFF", 1e-12))

# Outlier flagging settings
MIN_FRAC_OF_AUTOS = float(os.environ.get("MIN_FRAC_OF_AUTOS", .25))
# The correctly spelled variable takes precedence; the historical misspelling "MAX_AUTRO_L2"
# is still honored so that existing pipeline configurations keep working.
MAX_AUTO_L2 = float(os.environ.get("MAX_AUTO_L2", os.environ.get("MAX_AUTRO_L2", 1.2)))
Z_THRESH = float(os.environ.get("Z_THRESH", 5.0))
WS_Z_THRESH = float(os.environ.get("WS_Z_THRESH", 4.0))
AVG_Z_THRESH = float(os.environ.get("AVG_Z_THRESH", 1.5))
# Same as above: prefer the correct spelling, fall back to the historical "REPEAT_FLAG_Z_THESH".
REPEAT_FLAG_Z_THRESH = float(os.environ.get("REPEAT_FLAG_Z_THRESH", os.environ.get("REPEAT_FLAG_Z_THESH", 0.0)))
MAX_FREQ_FLAG_FRAC = float(os.environ.get("MAX_FREQ_FLAG_FRAC", .25))
MAX_TIME_FLAG_FRAC = float(os.environ.get("MAX_TIME_FLAG_FRAC", .1))

# Echo the settings actually in use so that they are recorded in the notebook output.
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'MAX_SOLAR_ALT', 'FREQ_FILTER_SCALE', 'TIME_FILTER_SCALE',
                'EIGENVAL_CUTOFF', 'MIN_FRAC_OF_AUTOS', 'MAX_AUTO_L2', 'Z_THRESH', 'WS_Z_THRESH', 'AVG_Z_THRESH',
                'REPEAT_FLAG_Z_THRESH', 'MAX_FREQ_FLAG_FRAC', 'MAX_TIME_FLAG_FRAC']:
    print(f'{setting} = {globals()[setting]}')
FM_LOW_FREQ = 87.5 FM_HIGH_FREQ = 108.0 MAX_SOLAR_ALT = 0.0 FREQ_FILTER_SCALE = 5.0 TIME_FILTER_SCALE = 450.0 EIGENVAL_CUTOFF = 1e-12 MIN_FRAC_OF_AUTOS = 0.25 MAX_AUTO_L2 = 1.2 Z_THRESH = 5.0 WS_Z_THRESH = 4.0 AVG_Z_THRESH = 1.5 REPEAT_FLAG_Z_THRESH = 2.0 MAX_FREQ_FLAG_FRAC = 0.25 MAX_TIME_FLAG_FRAC = 0.1
Load Data¶
In [5]:
def _find_files(file_glob, suffix):
    '''Return the sorted list of files matching file_glob and print how many were found.

    Arguments:
        file_glob: glob pattern to search for.
        suffix: file suffix, used only in the printed/raised messages.

    Raises:
        FileNotFoundError: if nothing matches file_glob.
    '''
    files = sorted(glob.glob(file_glob))
    if len(files) == 0:
        # Without this check, the summary print below would fail with a bare IndexError on files[0].
        raise FileNotFoundError(f'No *.{suffix} files found matching {file_glob}.')
    print(f'Found {len(files)} *.{suffix} files starting with {files[0]}.')
    return files


auto_sums = _find_files(auto_sums_glob, SUM_AUTOS_SUFFIX)
auto_diffs = _find_files(auto_diffs_glob, DIFF_AUTOS_SUFFIX)
cal_files = _find_files(cal_files_glob, CAL_SUFFIX)
ant_class_csvs = _find_files(ant_class_csvs_glob, ANT_CLASS_SUFFIX)
Found 1571 *.sum.autos.uvh5 files starting with /mnt/sn1/data2/2460552/zen.2460552.16914.sum.autos.uvh5. Found 1571 *.diff.autos.uvh5 files starting with /mnt/sn1/data2/2460552/zen.2460552.16914.diff.autos.uvh5. Found 1571 *.sum.omni.calfits files starting with /mnt/sn1/data2/2460552/zen.2460552.16914.sum.omni.calfits. Found 1571 *.sum.ant_class.csv files starting with /mnt/sn1/data2/2460552/zen.2460552.16914.sum.ant_class.csv.
In [6]:
# Read every per-file antenna classification csv, discarding rows that are entirely empty
tables = []
for csv_path in ant_class_csvs:
    tables.append(pd.read_csv(csv_path).dropna(axis=0, how='all'))

# After the first column, the columns of the first table are split by alternating position:
# odd positions go to table_cols, even positions (from 2 onward) go to class_cols
first_table_columns = tables[0].columns
table_cols = first_table_columns[1::2]
class_cols = first_table_columns[2::2]
In [7]:
# Find the antenna-polarizations that were never classified as bad while the sun was below MAX_SOLAR_ALT,
# ignoring the overall Antenna Class and the Even/Odd Zeros, Redcal chi^2, and Bad Diff X-Engines classes
ap_strs = np.array(tables[0]['Antenna'])
ant_flags = np.array([t[class_cols] for t in tables]) == 'bad'
sun_low_enough = np.array([t['Solar Alt'] < MAX_SOLAR_ALT for t in tables])
ants = sorted({int(a[:-1]) for a in ap_strs})

# indices into class_cols of the classifications that disqualify an antenna
ignored_classes = ['Antenna Class', 'Even/Odd Zeros Class', 'Redcal chi^2 Class', 'Bad Diff X-Engines Class']
cols_to_check = [cc for cc, colname in enumerate(class_cols) if colname not in ignored_classes]

candidate_autos = set()
for i, ap_str in enumerate(ap_strs):
    has_other_flags = np.any([(ant_flags[:, i, cc] & sun_low_enough[:, i]) for cc in cols_to_check])
    if has_other_flags:
        continue
    # e.g. split the antenna string into its number (all but the last character) and polarization (last character)
    ap = int(ap_str[:-1]), utils.comply_pol(ap_str[-1])
    candidate_autos.add(utils.join_bl(ap, ap))
In [8]:
# Read the sum and diff autos of each file, dropping any candidate whose autos show packet loss
good_data = {}
info_dicts = {}
for sf, df in zip(auto_sums, auto_diffs):
    rv = io.read_hera_hdf5(sf, bls=candidate_autos)
    good_data[sf], info_dicts[sf] = rv['data'], rv['info']
    diff = io.read_hera_hdf5(df, bls=candidate_autos)['data']

    # candidate_autos shrinks as we go, so later files are only read for the surviving candidates
    zeros_class = ant_class.even_odd_zeros_checker(good_data[sf], diff)
    for ant in zeros_class.bad_ants:
        candidate_autos.remove(utils.join_bl(ant, ant))
In [9]:
# Load the per-file calibration solutions into a CalibrationSmoother. Its time grid, frequencies,
# per-file time indices, and flag grids are what the rest of this notebook works on.
cs = CalibrationSmoother(cal_files, load_cspa=False, load_chisq=False, pick_refant=False)
In [10]:
# Build a per-integration boolean mask of the times flagged a priori in the YAML file (if one was given)
if APRIORI_YAML_PATH is not None:
    print(f'Loading a priori flagged times from {APRIORI_YAML_PATH}')
    apriori_flags = np.zeros(len(cs.time_grid), dtype=bool)
    flagged_int_indices = metrics_io.read_a_priori_int_flags(APRIORI_YAML_PATH, times=cs.time_grid).astype(int)
    apriori_flags[flagged_int_indices] = True
Loading a priori flagged times from /mnt/sn1/data2/2460552/2460552_apriori_flags.yaml
Figure out a subset of most-stable antennas to filter and flag on¶
In [11]:
# A (time, freq) sample starts out flagged only if it is flagged in every one of the calibration flag grids
initial_cal_flags = np.all(list(cs.flag_grids.values()), axis=0)
In [12]:
def average_autos(per_file_autos, bls_to_use, auto_sums, cs):
    '''Average autocorrelations over baselines and place them on the time grid of a CalibrationSmoother.

    Arguments:
        per_file_autos: dictionary mapping each file in auto_sums to a dictionary of waterfalls keyed by baseline.
        bls_to_use: baselines to average together.
        auto_sums: list of files, paired in order with cs.cals.
        cs: object with time_grid, freqs, cals, and time_indices (keyed by entries of cs.cals).

    Returns:
        Float array of shape (len(cs.time_grid), len(cs.freqs)) holding the absolute value of the
        baseline-averaged autos. Times not covered by any file are left at zero.
    '''
    # mean over the requested baselines, computed separately for every file
    file_means = {}
    for sum_file in auto_sums:
        stack = [per_file_autos[sum_file][bl] for bl in bls_to_use]
        file_means[sum_file] = np.mean(stack, axis=0)

    # drop each file's average into the rows of the full-day waterfall that its calibration file covers
    n_times, n_freqs = len(cs.time_grid), len(cs.freqs)
    waterfall = np.zeros((n_times, n_freqs), dtype=float)
    for sum_file, cal_file in zip(auto_sums, cs.cals):
        waterfall[cs.time_indices[cal_file], :] = np.abs(file_means[sum_file])
    return waterfall
In [13]:
# Average all candidate autos together into a single waterfall on the calibration time grid
avg_candidate_auto = average_autos(good_data, candidate_autos, auto_sums, cs)
In [14]:
def flag_FM(flags, freqs, freq_range=[87.5e6, 108e6]):
    '''Flag, in place, every frequency channel lying within freq_range (inclusive at both ends).

    Arguments:
        flags: 2D boolean array whose second axis runs over freqs. Modified in place.
        freqs: array of frequencies in Hz.
        freq_range: [low, high] band edges in Hz.
    '''
    low, high = freq_range[0], freq_range[1]
    in_band = (freqs >= low) & (freqs <= high)
    flags[:, in_band] = True
In [15]:
# Flag the FM band (FM_freq_range, in Hz) at all times; modifies initial_cal_flags in place
flag_FM(initial_cal_flags, cs.freqs, freq_range=FM_freq_range)
In [16]:
def flag_sun(flags, times, max_solar_alt=0):
    '''Flag, in place, all times at which the solar altitude is greater than or equal to max_solar_alt.

    Arguments:
        flags: 2D boolean array whose first axis runs over times. Modified in place.
        times: times passed to utils.get_sun_alt to compute the solar altitude.
        max_solar_alt: altitude threshold in degrees.
    '''
    sun_is_up = utils.get_sun_alt(times) >= max_solar_alt
    flags[sun_is_up, :] = True
In [17]:
# Flag all frequencies at times when the sun is at or above MAX_SOLAR_ALT; modifies initial_cal_flags in place
flag_sun(initial_cal_flags, cs.time_grid, max_solar_alt=MAX_SOLAR_ALT)
In [18]:
# Flag every channel of each integration that was flagged a priori (only if a YAML file was provided)
if APRIORI_YAML_PATH is not None:
    initial_cal_flags[apriori_flags] = True
In [19]:
def predict_auto_noise(auto, dt, df, nsamples=1):
    '''Predict the noise on an (antenna-averaged) autocorrelation.

    Arguments:
        auto: autocorrelation value or array.
        dt: integration time. The product dt * df must be unitless.
        df: channel width. The product dt * df must be unitless.
        nsamples: number of autocorrelations averaged together to produce auto.

    Returns:
        |auto| / sqrt(int(dt * df) * nsamples / 2), with the same shape as auto.
    '''
    # dt * df is truncated to an integer before being scaled by the number of averaged autos
    n_independent = int(dt * df) * nsamples
    amplitude = np.abs(auto)
    return amplitude / np.sqrt(n_independent / 2)
In [20]:
# Figure out noise and weights
# Integration time in seconds: the median spacing of the time grid (in JD, i.e. days) times seconds per day
int_time = 24 * 3600 * np.median(np.diff(cs.time_grid))
# Channel width: the median spacing of the frequency grid (in Hz)
chan_res = np.median(np.diff(cs.freqs))
# Predicted noise on the candidate-averaged autocorrelation
# NOTE(review): nsamples=1 is used even though avg_candidate_auto is an average over many autos;
# this rescales noise (and hence wgts) by a constant factor -- confirm that is intended.
noise = predict_auto_noise(avg_candidate_auto, int_time, chan_res, nsamples=1)
# Inverse-variance weights, set to zero wherever the data are initially flagged
wgts = np.where(initial_cal_flags, 0, noise**-2)
In [21]:
# get slices to index into region of waterfall outside of which it's 100% flagged
# np.flatnonzero always returns a 1D index array, even when only a single integration or channel is
# unflagged (np.squeeze(np.argwhere(...)) would collapse that case to a 0-d array that cannot be indexed)
unflagged_ints = np.flatnonzero(~np.all(initial_cal_flags, axis=1))
unflagged_chans = np.flatnonzero(~np.all(initial_cal_flags, axis=0))
if unflagged_ints.size == 0 or unflagged_chans.size == 0:
    raise ValueError('All data are flagged before filtering; there is no unflagged region to filter.')

# slices run from the first through the last not-entirely-flagged integration/channel, inclusive
ints_to_filt = slice(unflagged_ints[0], unflagged_ints[-1] + 1)
chans_to_filt = slice(unflagged_chans[0], unflagged_chans[-1] + 1)
In [22]:
# DPSS-filter each candidate autocorrelation on its own and score how well it is described by the model
cached_output = {}
models = {}
sqrt_mean_sqs = {}

# the filters are built once, over only the part of the waterfall that is not entirely flagged
time_filters, freq_filters = dpss_filters(
    freqs=cs.freqs[chans_to_filt],  # Hz
    times=cs.time_grid[ints_to_filt],  # JD
    freq_scale=FREQ_FILTER_SCALE,
    time_scale=TIME_FILTER_SCALE,
    eigenval_cutoff=EIGENVAL_CUTOFF,
)

for bl in candidate_autos:
    auto_here = average_autos(good_data, [bl], auto_sums, cs)

    # start from a copy of the data, so that the model equals the data outside the filtered region
    models[bl] = np.array(auto_here)
    # cached_output is fed back in on every iteration so that solve_2D_DPSS can reuse it across baselines
    model, cached_output = solve_2D_DPSS(
        auto_here[ints_to_filt, chans_to_filt],
        wgts[ints_to_filt, chans_to_filt],
        time_filters,
        freq_filters,
        method='lu_solve',
        cached_input=cached_output,
    )
    models[bl][ints_to_filt, chans_to_filt] = model

    # root-mean-square of the residual z-score over all samples that are not initially flagged
    noise_model = predict_auto_noise(models[bl], int_time, chan_res, nsamples=1)
    resid_zscore = np.where(initial_cal_flags, np.nan, (auto_here - models[bl]) / noise_model)
    sqrt_mean_sqs[bl] = np.nanmean(resid_zscore**2)**.5
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
In [23]:
# Keep the most stable autocorrelations. The cut is the larger of MAX_AUTO_L2 and the MIN_FRAC_OF_AUTOS
# quantile of the scores, so that roughly that fraction of the candidates is always retained
quantile_bound = np.quantile(list(sqrt_mean_sqs.values()), MIN_FRAC_OF_AUTOS)
L2_bound = max(quantile_bound, MAX_AUTO_L2)
good_auto_bls = [bl for bl in candidate_autos if sqrt_mean_sqs[bl] <= L2_bound]

frac_used = len(good_auto_bls) / len(candidate_autos)
print(f'Using {len(good_auto_bls)} out of {len(candidate_autos)} candidate autocorrelations ({frac_used:.2%}).')
Using 63 out of 113 candidate autocorrelations (55.75%).
In [24]:
# Extent used for the imshow waterfalls below, as [left, right, bottom, top]: frequency in MHz on x, and JD
# minus the integer part of the first JD on y, with the last time at the bottom and the first at the top
extent = [cs.freqs[0]/1e6, cs.freqs[-1]/1e6, cs.time_grid[-1] - int(cs.time_grid[0]), cs.time_grid[0] - int(cs.time_grid[0])]
In [25]:
def plot_all_filtered_bls(N_per_row=8):
    '''Plot the z-score waterfall of the DPSS-filtered residual of every candidate autocorrelation.

    Panels are sorted by increasing root-mean-square z-score (sqrt_mean_sqs), which is printed in each
    panel title; titles are red for autos above L2_bound. Afterwards, prints the antenna-polarizations
    that are not plotted because they are not in candidate_autos.

    Relies on the notebook-level variables candidate_autos, sqrt_mean_sqs, models, good_data, auto_sums,
    cs, initial_cal_flags, int_time, chan_res, extent, L2_bound, and ap_strs.

    Arguments:
        N_per_row: number of panels per row of the figure.
    '''
    N_rows = int(np.ceil(len(candidate_autos) / N_per_row))
    # squeeze=False keeps axes 2D even when there is a single row or column,
    # so that axes.flatten() and axes[-1, :] below always work
    fig, axes = plt.subplots(N_rows, N_per_row, figsize=(14, 3 * N_rows), dpi=100, squeeze=False,
                             sharex=True, sharey=True, gridspec_kw={'wspace': 0, 'hspace': .18})

    for i, (ax, bl) in enumerate(zip(axes.flatten(), sorted(sqrt_mean_sqs.keys(), key=lambda bl: sqrt_mean_sqs[bl]))):
        auto_here = average_autos(good_data, [bl], auto_sums, cs)
        noise_model = predict_auto_noise(models[bl], int_time, chan_res, nsamples=1)

        # residual z-score, blanked wherever the data are initially flagged
        im = ax.imshow(np.where(initial_cal_flags, np.nan, (auto_here - models[bl]) / noise_model).real,
                       aspect='auto', interpolation='none', cmap='bwr', vmin=-10, vmax=10, extent=extent)
        ax.set_title(f'{bl[0]}{bl[2][0]}: {sqrt_mean_sqs[bl]:.3}', color=('k' if sqrt_mean_sqs[bl] <= L2_bound else 'r'), fontsize=10)

        if i == 0:
            # a single colorbar shared by all panels
            plt.colorbar(im, ax=axes, location='top', label=r'Autocorrelation z-score after DPSS filtering (with $\langle z^2 \rangle^{1/2}$)', extend='both', aspect=40, pad=.015)
        if i % N_per_row == 0:
            # only the left-most panel of each row gets a y label
            ax.set_ylabel(f'JD - {int(cs.time_grid[0])}')

    for ax in axes[-1, :]:
        ax.set_xlabel('Frequency (MHz)')

    plt.tight_layout()
    plt.show()

    # list the antenna-polarizations that were excluded from candidate_autos and so have no panel
    antpols = [(int(ap[:-1]), utils.comply_pol(ap[-1])) for ap in ap_strs]
    other_autos = [f'{ap[0]}{ap[-1][-1]}' for ap in antpols if utils.join_bl(ap, ap) not in candidate_autos]
    print('Not plotted here due to prior antenna flagging:')
    print('\t' + '\n\t'.join(textwrap.wrap(', '.join(other_autos), 80, break_long_words=False)))
Figure 1: Show All DPSS Residual z-Scores¶
This figure shows the z-score waterfall of each antenna. Also shown is the square root of the mean of the square of each waterfall, as a metric of its instability. Antennas in red are excluded from the average of most stable antennas that are used for subsequent flagging.
In [26]:
plot_all_filtered_bls()
This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.