diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 06792a451..92f8960b4 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -870,25 +870,78 @@ def reindex_axis(self, axis, indexes, in_place=True): else: # By this point we have a non AxisManager # assignment assigned to only our axis. - # Build new array with the correct indexes. - shape = [len(indexes)] - if isinstance(v, np.ndarray): - for s in np.shape(v)[1:]: - shape.append(s) - - new_v = np.empty(shape, dtype=v.dtype) - if isinstance(v.dtype, float): - # Fill any float arrays with nans - # Non float arrays may have weird - # behavior for newly added indexes. - # Oh well. - new_v *= np.nan - - for i, index in enumerate(indexes): - if np.isnan(index) or not (0 <= index < len(v)): - continue - new_v[i] = v[int(index)] + from so3g.proj import RangesMatrix, Ranges + import scipy.sparse as sp + + if isinstance(v, RangesMatrix): + # RangesMatrix has no dtype; build a new one by selecting + # ranges, using empty Ranges for missing indices. + if len(v.shape) >= 2: + nsamps = v.shape[-1] + elif len(v.ranges) > 0 and hasattr(v.ranges[0], 'count'): + nsamps = v.ranges[0].count + else: + nsamps = 0 + + new_ranges = [] + for index in indexes: + if (isinstance(index, float) and np.isnan(index)) \ + or not (0 <= int(index) < len(v.ranges)): + new_ranges.append(Ranges(nsamps)) + else: + new_ranges.append(v.ranges[int(index)]) + new_v = RangesMatrix(new_ranges) + + elif sp.issparse(v): + # Sparse matrix: select rows without densifying. + # Invalid indices become all-zero rows. + v_csr = v.tocsr() + n_rows_in = v_csr.shape[0] + valid_mask = np.array([ + not (isinstance(idx, float) and np.isnan(idx)) + and 0 <= int(idx) < n_rows_in + for idx in indexes + ]) + safe = np.array([ + int(idx) if valid else 0 + for idx, valid in zip(indexes, valid_mask) + ], dtype=np.int64) + selected = v_csr[safe] + if not valid_mask.all(): + diag = sp.diags(valid_mask.astype(v_csr.dtype)) + selected = (diag @ selected).tocsr() + new_v = selected + + else: + # ndarray (or anything else with len/shape/dtype/__getitem__) + shape = [len(indexes)] + if isinstance(v, np.ndarray): + for s in np.shape(v)[1:]: + shape.append(s) + elif hasattr(v, 'shape') and len(v.shape) > 1: + for s in v.shape[1:]: + shape.append(s) + + if v.dtype == float: + # Fill any float arrays with nans + # Non float arrays may have weird + # behavior for newly added indexes. + # Oh well. + new_v = np.full(shape, np.nan) + elif v.dtype == int: + # Fill with -1 values which generally + # correspond ot unassigned (e.g. bias lines) + new_v = np.full(shape, -1) + else: + new_v = np.zeros(shape, dtype=v.dtype) + + v_len = v.shape[0] if hasattr(v, 'shape') and len(v.shape) > 0 else len(v) + for i, index in enumerate(indexes): + if np.isnan(index) or not (0 <= index < v_len): + continue + + new_v[i] = v[int(index)] reindexed_vs[assignment] = new_v new_axes[assignment] = np.array(axes) diff --git a/sotodlib/core/context.py b/sotodlib/core/context.py index c0899247c..9687f6fba 100644 --- a/sotodlib/core/context.py +++ b/sotodlib/core/context.py @@ -356,21 +356,25 @@ def get_obs(self, # Both tones and det_info exist. else: - # Grab all band and channel info for dets + tdets + # Grab all stream, band, and channel info for dets + tdets + det_streams = aman.det_info.stream_id det_bands = aman.det_info.smurf.band det_channels = aman.det_info.smurf.channel + tdet_streams = aman.tones.stream_id tdet_bands = aman.tones.band tdet_channels = aman.tones.channel # Create a sorted array of dets + tdets - special_band_ch = [(b, c) for b, c in zip(tdet_bands, tdet_channels)] - normal_band_ch = [(b, c) for b, c in zip(det_bands, det_channels)] - band_ch = np.array(sorted(normal_band_ch + special_band_ch)) + special_band_ch = [(s, b, c) for s, b, c in zip(tdet_streams, tdet_bands, tdet_channels)] + normal_band_ch = [(s, b, c) for s, b, c in zip(det_streams, det_bands, det_channels)] + band_ch = sorted(normal_band_ch + special_band_ch) # Grab the det idxs from the det band + channels det_indexes = np.full(len(band_ch), np.nan) - for i, (b, c) in enumerate(band_ch): - w = np.where((det_bands == b) & (det_channels == c))[0] + for i, (s, b, c) in enumerate(band_ch): + w = np.where((det_streams == s) & \ + (det_bands == b) & \ + (det_channels == c))[0] if len(w) == 0: continue @@ -378,8 +382,10 @@ def get_obs(self, # Grab the tdet idxs from the tdet band + channels tdet_indexes = np.full(len(band_ch), np.nan) - for i, (b, c) in enumerate(band_ch): - w = np.where((tdet_bands == b) & (tdet_channels == c))[0] + for i, (s, b, c) in enumerate(band_ch): + w = np.where((tdet_streams == s) & \ + (tdet_bands == b) & \ + (tdet_channels == c))[0] if len(w) == 0: continue @@ -393,13 +399,20 @@ def get_obs(self, # Finally use the tdet idxs to fill in the tdet data # For the signal, band, and channels + + # If there is no signal, need to pre-populate the signal axis + if no_signal: + del aman['signal'] + aman.wrap_new('signal', ('dets', 'samps'), dtype='float32') for i, tidx in enumerate(tdet_indexes): if np.isnan(tidx): continue - + aman.signal[i] = aman.tones.signal[int(tidx)] + aman.det_info.stream_id[i] = aman.tones.stream_id[int(tidx)] aman.det_info.smurf.channel[i] = aman.tones.channel[int(tidx)] aman.det_info.smurf.band[i] = aman.tones.band[int(tidx)] + aman.det_info.wafer.type[i] = 'PROB' def add_tdet_ids(aman, tdet_indexes, tdet_ids): """ diff --git a/sotodlib/core/metadata/loader.py b/sotodlib/core/metadata/loader.py index 6147fcaba..26b812fd9 100644 --- a/sotodlib/core/metadata/loader.py +++ b/sotodlib/core/metadata/loader.py @@ -323,7 +323,10 @@ def load_one(self, spec, request, det_info): if len(results) == 1: result = results[0] else: - result = results[0].concatenate(results) + if isinstance(results[0], core.AxisManager): + result = results[0].concatenate(results, other_fields='first') + else: + result = results[0].concatenate(results) return result def load(self, spec_list, request, det_info=None, free_tags=[], diff --git a/sotodlib/preprocess/preprocess_util.py b/sotodlib/preprocess/preprocess_util.py index 796c4793c..c841f287a 100644 --- a/sotodlib/preprocess/preprocess_util.py +++ b/sotodlib/preprocess/preprocess_util.py @@ -587,6 +587,7 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, pipe = Pipeline(configs["process_pipe"], logger=logger) aman = context.get_obs(meta, no_signal=no_signal) pipe.run(aman, aman.preprocess, select=False) + return aman, full_aman diff --git a/sotodlib/site_pipeline/make_ml_map.py b/sotodlib/site_pipeline/make_ml_map.py index 92f94e05c..2c4c3db08 100644 --- a/sotodlib/site_pipeline/make_ml_map.py +++ b/sotodlib/site_pipeline/make_ml_map.py @@ -43,6 +43,7 @@ def get_parser(parser=None): parser.add_argument( "--moon-mask", type=str, default="/global/cfs/cdirs/sobs/users/sigurdkn/masks/sidelobe/moon.fits", help="Location of Moon sidelobe mask") parser.add_argument("--hits", action="store_true", help="Write hits maps") parser.add_argument("--cut-type", type=str, default="full") + parser.add_argument( "--fixed-tones", type=int, default=0, help="0: don't include fixed tones in noise model. Nonzero: include fixed tones in noise model") return parser unit_defs = {"nK":1e-9, "uK":1e-6, "mK":1e-3, "K":1.0, "kK":1e3, "MK": 1e6, "GK":1e9} @@ -217,10 +218,130 @@ def main(**args): if args.max_dets is not None: meta.restrict('dets', meta['dets'].vals[:args.max_dets]) if len(my_dets) == 0: raise DataMissing("no dets left") + # Load fixed tones before context changes + if args.fixed_tones: + L.debug('Loading tones') + tones_aman = context.get_obs(meta.obs_info.obs_id, no_signal=True, + special_channels=True, + reindex_dets=True) + # Remove preprocessing -- it's irrelevant + if 'preprocess' in tones_aman._fields: + del tones_aman['preprocess'] + # Get relevant UFMs + keep = np.isin(tones_aman.det_info.stream_id, np.unique(meta.det_info.stream_id)) + tones_aman.restrict("dets", keep) + + # Get just tones + keep = tones_aman.det_info.wafer.type == 'PROB' + tones_aman.restrict("dets", keep) + + # Set calibrations + tones_aman.det_cal.phase_to_pW[tones_aman.det_info.wafer.type == 'PROB'] = np.nanmedian(meta.det_cal.phase_to_pW) + tones_aman.relcal.relcal[tones_aman.det_info.wafer.type == 'PROB'] = 1 + for k in tones_aman.abscal._fields: + tones_aman.abscal[k][tones_aman.det_info.wafer.type == 'PROB'] = np.nanmedian(meta.abscal[k]) + + # Detrend + tones_aman.signal -= np.median(tones_aman.signal, axis=-1)[..., None] + utils.deslope(tones_aman.signal, w=5, inplace=True) + + # Convert to K_CMB + tones_aman.signal *= tones_aman.det_cal.phase_to_pW[..., None] + tones_aman.signal *= tones_aman.abscal.abscal_cmb[..., None] + tones_aman.signal *= tones_aman.relcal.relcal[..., None] + # Actually read the data + L.debug('Loading obs') with bench.mark("read_obs %s" % sub_id): #obs = context.get_obs(sub_id, meta=meta) obs, _ = pp_util.load_and_preprocess(obs_id, preproc, context=context, meta=meta) + + # Add special channels + if args.fixed_tones: + # Helper functions for filling out to match amans + def zeros_like_aman(template, dets_axis): + """ + Build a new AxisManager mirroring `template`'s structure, with the dets + axis swapped for `dets_axis`. Array fields are zero-filled; RangesMatrix + fields are built as empty (no flagged samples). + """ + new_axes = [] + for ax_name, ax in template._axes.items(): + new_axes.append(dets_axis if ax_name == 'dets' else ax) + out = core.AxisManager(*new_axes) + + for name, field in template._fields.items(): + assignment = template._assignments[name] + + if isinstance(field, core.AxisManager): + out.wrap(name, zeros_like_aman(field, dets_axis)) + continue + + if assignment is None or all(a is None for a in assignment): + # Non-axis-aligned metadata — copy as-is + out.wrap(name, field) + continue + + if isinstance(field, so3g.proj.RangesMatrix): + shape = tuple( + (dets_axis.count if a == 'dets' else template._axes[a].count) + for a in assignment if a is not None + ) + rm = so3g.proj.RangesMatrix.zeros(shape) + wrap_axes = [(i, a) for i, a in enumerate(assignment) if a is not None] + out.wrap(name, rm, wrap_axes) + continue + + # Plain ndarray field + shape = [] + for dim_size, ax_name in zip(field.shape, assignment): + shape.append(dets_axis.count if ax_name == 'dets' else dim_size) + arr = np.zeros(shape, dtype=getattr(field, 'dtype', float)) + wrap_axes = [(i, a) for i, a in enumerate(assignment) if a is not None] + out.wrap(name, arr, wrap_axes) + + return out + + def sync_nested_amans(target, reference): + """Make `target` match `reference`'s field structure, using target's dets axis + and zero/empty values for anything missing.""" + for name, field in reference._fields.items(): + assignment = reference._assignments[name] + + if name not in target._fields: + # Field is entirely missing — build a zeroed version + if isinstance(field, core.AxisManager): + target.wrap(name, zeros_like_aman(field, target.dets)) + elif assignment is None or all(a is None for a in assignment): + target.wrap(name, field) # scalar metadata — share + elif isinstance(field, so3g.proj.RangesMatrix): + shape = tuple( + target.dets.count if a == 'dets' else reference._axes[a].count + for a in assignment if a is not None + ) + rm = so3g.proj.RangesMatrix.zeros(shape) + wrap_axes = [(i, a) for i, a in enumerate(assignment) if a is not None] + target.wrap(name, rm, wrap_axes) + else: + shape = tuple( + target.dets.count if a == 'dets' else reference._axes[a].count + for a in assignment if a is not None + ) + arr = np.zeros(shape, dtype=getattr(field, 'dtype', float)) + wrap_axes = [(i, a) for i, a in enumerate(assignment) if a is not None] + target.wrap(name, arr, wrap_axes) + elif isinstance(field, core.AxisManager): + # Both sides have it as a nested aman — recurse + sync_nested_amans(target._fields[name], field) + + # Apply + sync_nested_amans(tones_aman, obs) + + # Combine + L.debug('Combining') + obs = core.AxisManager.concatenate([obs, tones_aman], axis='dets', other_fields='first') + del tones_aman + if obs.dets.count < 50: L.debug("Skipped %s (Not enough detectors)" % (sub_id)) L.debug("Datacount: %s full" % (sub_id)) @@ -237,8 +358,49 @@ def main(**args): if np.any(zero_dets == 0.0): L.debug("%s has all 0s in at least 1 detector" % (sub_id)) obs.restrict('dets', obs.dets.vals[np.logical_not(zero_dets == 0.0)]) + + # Zero out T and P contributions to pointing matrix for fixed tones + obs.focal_plane.wrap('T', np.ones(obs.dets.count), [(0, obs.focal_plane.dets)]) + obs.focal_plane.wrap('P', np.ones(obs.dets.count), [(0, obs.focal_plane.dets)]) + if args.fixed_tones: + # Ensure gammas are zero + obs.focal_plane.gamma[obs.det_info.wafer.type == 'PROB'] = 0 + + # Set xi, eta to nearest matches + for idx in np.argwhere(obs.det_info.wafer.type == 'PROB').flatten(): + # Get possible channels + chans = obs.det_info.smurf.channel[(obs.det_info.wafer.type == 'OPTC') & \ + (obs.det_info.stream_id == obs.det_info.stream_id[idx]) & \ + (obs.det_info.smurf.band == obs.det_info.smurf.band[idx])] + + # Detectors on this band may have been filtered + if len(chans) == 0: + continue + + # Find closest match + opt_chan = chans[np.argmin(np.abs(chans - obs.det_info.smurf.channel[idx]))] + + # Get index + opt_idx = np.argwhere((obs.det_info.wafer.type == 'OPTC') & \ + (obs.det_info.stream_id == obs.det_info.stream_id[idx]) & \ + (obs.det_info.smurf.band == obs.det_info.smurf.band[idx]) & \ + (obs.det_info.smurf.channel == opt_chan)).flatten() + assert len(opt_idx) == 1 + + # Set xi, eta accordingly + obs.focal_plane.xi[idx] = obs.focal_plane.xi[opt_idx[0]] + obs.focal_plane.eta[idx] = obs.focal_plane.eta[opt_idx[0]] + + # Zero out T, P (not natively in focal plane) + obs.focal_plane.T[obs.det_info.wafer.type == 'PROB'] = 0 + obs.focal_plane.P[obs.det_info.wafer.type == 'PROB'] = 0 + # Cut non-optical dets, this will be redundant if the preprocessing already cut them - obs.restrict('dets', obs.dets.vals[obs.det_info.wafer.type == 'OPTC']) + # Keep fixed tones if desired + if args.fixed_tones: + obs.restrict('dets', obs.dets.vals[(obs.det_info.wafer.type == 'OPTC') | (obs.det_info.wafer.type == 'PROB')]) + else: + obs.restrict('dets', obs.dets.vals[obs.det_info.wafer.type == 'OPTC']) # Fix boresight mapmaking.fix_boresight_glitches(obs) # Get our sample rate. Would have been nice to have this available in the axisman @@ -272,6 +434,9 @@ def main(**args): rms = d1u.measure_rms(obs.signal, dt=1/srate) rms *= unit_defs[args.unit]/unit_defs["uK"] good = d1u.sensitivity_cut(rms, d1u.SENS_LIMITS[band]) + # Force fixed tones to satisfy + if args.fixed_tones: + good[obs.det_info.wafer.type == 'PROB'] = True if np.logical_not(good).sum() / obs.dets.count > 0.5: L.debug("Skipped %s (more than 50 percent of detectors cut by sens)" % (sub_id)) L.debug("Datacount: %s full" % (sub_id)) @@ -281,6 +446,9 @@ def main(**args): obs.restrict("dets", good) # Disqualify overly cut detectors good_dets = mapmaking.find_usable_detectors(obs, args.maxcut) + # Force fixed tones to satisfy + if args.fixed_tones: + good_dets = np.unique(np.concatenate((good_dets, obs.dets.vals[obs.det_info.wafer.type == 'PROB']))) obs.restrict("dets", good_dets) if obs.dets.count == 0: to_skip += [sub_id] @@ -316,6 +484,9 @@ def main(**args): nmat = mapmaking.read_nmat(nmat_file) else: nmat = None + L.debug('Add obs: %i detectors; %i optical, %i fixed' % \ + (obs.signal.shape[0], np.sum(obs.det_info.wafer.type == 'OPTC'), np.sum(obs.det_info.wafer.type == 'PROB'))) + # And add it to the mapmaker with bench.mark("add_obs %s" % sub_id): if ipass > 0: