Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combining dev and main #155

Merged
merged 16 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 1 addition & 36 deletions src/schedlib/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@


MIN_DURATION = 0.01
HWP_SPIN_DOWN = 15*u.minute
HWP_SPIN_UP = 7*u.minute
HWP_SPIN_DOWN = 15*u.minute

@dataclass_json
@dataclass(frozen=True)
Expand Down Expand Up @@ -120,41 +120,6 @@ def increment_time_sec(self, dt_sec: float) -> "State":
"""
return self.replace(curr_time=self.curr_time+dt.timedelta(seconds=dt_sec))


class SchedMode:
"""
Enumerate different options for scheduling operations in SATPolicy.

Attributes
----------
PreCal : str
'pre_cal'; Operations scheduled before block.t0 for calibration.
PreObs : str
'pre_obs'; Observations scheduled before block.t0 for observation.
InCal : str
'in_cal'; Calibration operations scheduled between block.t0 and block.t1.
InObs : str
'in_obs'; Observation operations scheduled between block.t0 and block.t1.
PostCal : str
'post_cal'; Calibration operations scheduled after block.t1.
PostObs : str
'post_obs'; Observations operations scheduled after block.t1.
PreSession : str
'pre_session'; Represents the start of a session, scheduled from the beginning of the requested t0.
PostSession : str
'post_session'; Indicates the end of a session, scheduled after the last operation.

"""
PreCal = 'pre_cal'
PreObs = 'pre_obs'
InCal = 'in_cal'
InObs = 'in_obs'
PostCal = 'post_cal'
PostObs = 'post_obs'
PreSession = 'pre_session'
PostSession = 'post_session'


# -------------------------------------------------------------------------
# Register operations
# -------------------------------------------------------------------------
Expand Down
278 changes: 274 additions & 4 deletions src/schedlib/policies/sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .. import commands as cmd, instrument as inst, utils as u
from ..thirdparty import SunAvoidance
from .stages import get_build_stage
from .stages.build_op import get_parking

logger = u.init_logger(__name__)

Expand Down Expand Up @@ -69,6 +70,113 @@ class CalTarget:
az_speed: Optional[float]= None
az_accel: Optional[float] = None

@dataclass(frozen=True)
class WiregridTarget:
hour: int
el_target: float
az_target: float = 180
duration: float = 15*u.minute

class SchedMode:
"""
Enumerate different options for scheduling operations in SATPolicy.

Attributes
----------
PreCal : str
'pre_cal'; Operations scheduled before block.t0 for calibration.
PreObs : str
'pre_obs'; Observations scheduled before block.t0 for observation.
InCal : str
'in_cal'; Calibration operations scheduled between block.t0 and block.t1.
InObs : str
'in_obs'; Observation operations scheduled between block.t0 and block.t1.
PostCal : str
'post_cal'; Calibration operations scheduled after block.t1.
PostObs : str
'post_obs'; Observations operations scheduled after block.t1.
PreSession : str
'pre_session'; Represents the start of a session, scheduled from the beginning of the requested t0.
PostSession : str
'post_session'; Indicates the end of a session, scheduled after the last operation.

"""
PreCal = 'pre_cal'
PreObs = 'pre_obs'
InCal = 'in_cal'
InObs = 'in_obs'
PostCal = 'post_cal'
PostObs = 'post_obs'
PreSession = 'pre_session'
PostSession = 'post_session'
Wiregrid = 'wiregrid'

def make_cal_target(
source: str,
boresight: float,
elevation: float,
focus: str,
allow_partial=False,
drift=True,
az_branch=None,
az_speed=None,
az_accel=None,
) -> CalTarget:
array_focus = {
0 : {
'left' : 'ws3,ws2',
'middle' : 'ws0,ws1,ws4',
'right' : 'ws5,ws6',
'bottom': 'ws1,ws2,ws6',
'all' : 'ws0,ws1,ws2,ws3,ws4,ws5,ws6',
},
45 : {
'left' : 'ws3,ws4',
'middle' : 'ws2,ws0,ws5',
'right' : 'ws1,ws6',
'bottom': 'ws1,ws2,ws3',
'all' : 'ws0,ws1,ws2,ws3,ws4,ws5,ws6',
},
-45 : {
'left' : 'ws1,ws2',
'middle' : 'ws6,ws0,ws3',
'right' : 'ws4,ws5',
'bottom': 'ws1,ws6,ws5',
'all' : 'ws0,ws1,ws2,ws3,ws4,ws5,ws6',
},
}

boresight = float(boresight)
elevation = float(elevation)
focus = focus.lower()

focus_str = None
if int(boresight) not in array_focus:
logger.warning(
f"boresight not in {array_focus.keys()}, assuming {focus} is a wafer string"
)
focus_str = focus ##
else:
focus_str = array_focus[int(boresight)].get(focus, focus)

assert source in src.SOURCES, f"source should be one of {src.SOURCES.keys()}"

if az_branch is None:
az_branch = 180.

return CalTarget(
source=source,
array_query=focus_str,
el_bore=elevation,
boresight_rot=boresight,
tag=focus_str,
allow_partial=allow_partial,
drift=drift,
az_branch=az_branch,
az_speed=az_speed,
az_accel=az_accel,
)

# ----------------------------------------------------
# Register operations
# ----------------------------------------------------
Expand Down Expand Up @@ -268,7 +376,7 @@ def cmb_scan(state, block):
)
else:
commands = []

commands.extend([
"run.seq.scan(",
f" description='{block.name}',",
Expand Down Expand Up @@ -379,6 +487,12 @@ def bias_step(state, block, bias_step_cadence=None):
else:
return state, 0, []

@cmd.operation(name='sat.wiregrid', duration=15*u.minute)
def wiregrid(state):
return state, [
"run.wiregrid.calibrate(continuous=False, elevation_check=True, boresight_check=False, temperature_check=False)"
]

@dataclass
class SATPolicy:
"""a more realistic SAT policy.
Expand Down Expand Up @@ -447,6 +561,39 @@ def from_config(cls, config: Union[Dict[str, Any], str]):
config = yaml.load(config, Loader=loader)
return cls(**config)

def divide_blocks(self, block, max_dt=dt.timedelta(minutes=60), min_dt=dt.timedelta(minutes=15)):
duration = block.duration

# if the block is small enough, return it as is
if duration <= (max_dt + min_dt):
return [block]

n_blocks = duration // max_dt
remainder = duration % max_dt

# split if 1 block with remainder > min duration
if n_blocks == 1:
return core.block_split(block, block.t0 + max_dt)

blocks = []
# calculate the offset for splitting
offset = (remainder + max_dt) / 2 if remainder.total_seconds() > 0 else max_dt

split_blocks = core.block_split(block, block.t0 + offset)
blocks.append(split_blocks[0])

# split the remaining block into chunks of max duration
for i in range(n_blocks - 1):
split_blocks = core.block_split(split_blocks[-1], split_blocks[-1].t0 + max_dt)
blocks.append(split_blocks[0])

# add the remaining part
if remainder.total_seconds() > 0:
split_blocks = core.block_split(split_blocks[-1], split_blocks[-1].t0 + offset)
blocks.append(split_blocks[0])

return blocks

def init_seqs(self, t0: dt.datetime, t1: dt.datetime) -> core.BlocksTree:
"""
Initialize the sequences for the scheduler to process.
Expand Down Expand Up @@ -494,6 +641,26 @@ def construct_seq(loader_cfg):
source = cal_target.source
if source not in blocks['calibration']:
blocks['calibration'][source] = src.source_gen_seq(source, t0, t1)
elif isinstance(cal_target, WiregridTarget):
wiregrid_candidates = []
current_date = t0.date()
end_date = t1.date()

while current_date <= end_date:
candidate_time = dt.datetime.combine(current_date, dt.time(cal_target.hour, 0), tzinfo=dt.timezone.utc)
if t0 <= candidate_time <= t1:
wiregrid_candidates.append(
inst.StareBlock(
name='wiregrid',
t0=candidate_time,
t1=candidate_time + dt.timedelta(seconds=cal_target.duration),
az=cal_target.az_target,
alt=cal_target.el_target,
subtype='wiregrid'
)
)
current_date += dt.timedelta(days=1)
blocks['calibration']['wiregrid'] = wiregrid_candidates

# trim to given time range
blocks = core.seq_trim(blocks, t0, t1)
Expand Down Expand Up @@ -548,7 +715,13 @@ def apply(self, blocks: core.BlocksTree) -> core.BlocksTree:

for target in self.cal_targets:
logger.info(f"-> planning calibration scans for {target}...")


if isinstance(target, WiregridTarget):
logger.info(f"-> planning wiregrid scans for {target}...")
cal_blocks += core.seq_map(lambda b: b.replace(subtype='wiregrid'),
blocks['calibration']['wiregrid'])
continue

assert target.source in blocks['calibration'], f"source {target.source} not found in sequence"

# digest array_query: it could be a fnmatch pattern matching the path
Expand Down Expand Up @@ -643,7 +816,7 @@ def apply(self, blocks: core.BlocksTree) -> core.BlocksTree:

# add proper subtypes
blocks['calibration'] = core.seq_map(
lambda block: block.replace(subtype="cal"),
lambda block: block.replace(subtype="cal") if block.name != 'wiregrid' else block,
blocks['calibration']
)

Expand Down Expand Up @@ -729,7 +902,99 @@ def seq2cmd(

# load building stage
build_op = get_build_stage('build_op', {'policy_config': self, **self.stages.get('build_op', {})})
ops, state = build_op.apply(seq, t0, t1, state, self.operations)

# first resolve overlapping between cal and cmb
cal_blocks = core.seq_flatten(core.seq_filter(lambda b: b.subtype == 'cal', seq))
cmb_blocks = core.seq_flatten(core.seq_filter(lambda b: b.subtype == 'cmb', seq))
wiregrid_blocks = core.seq_flatten(core.seq_filter(lambda b: b.subtype == 'wiregrid', seq))
cal_blocks += wiregrid_blocks
seq = core.seq_sort(core.seq_merge(cmb_blocks, cal_blocks, flatten=True))

# divide cmb blocks
if self.max_cmb_scan_duration is not None:
seq = core.seq_flatten(core.seq_map(lambda b: self.divide_blocks(b, dt.timedelta(seconds=self.max_cmb_scan_duration)) if b.subtype=='cmb' else b, seq))

# compile operations
cal_pre = [op for op in self.operations if op['sched_mode'] == SchedMode.PreCal]
cal_in = [op for op in self.operations if op['sched_mode'] == SchedMode.InCal]
cal_post = [op for op in self.operations if op['sched_mode'] == SchedMode.PostCal]
cmb_pre = [op for op in self.operations if op['sched_mode'] == SchedMode.PreObs]
cmb_in = [op for op in self.operations if op['sched_mode'] == SchedMode.InObs]
cmb_post = [op for op in self.operations if op['sched_mode'] == SchedMode.PostObs]
pre_sess = [op for op in self.operations if op['sched_mode'] == SchedMode.PreSession]
pos_sess = [op for op in self.operations if op['sched_mode'] == SchedMode.PostSession]
wiregrid_in = [op for op in self.operations if op['sched_mode'] == SchedMode.Wiregrid]

def map_block(block):
if block.subtype == 'cal':
return {
'name': block.name,
'block': block,
'pre': cal_pre,
'in': cal_in,
'post': cal_post,
'priority': 3
}
elif block.subtype == 'cmb':
return {
'name': block.name,
'block': block,
'pre': cmb_pre,
'in': cmb_in,
'post': cmb_post,
'priority': 1
}
elif block.subtype == 'wiregrid':
return {
'name': block.name,
'block': block,
'pre': [],
'in': wiregrid_in,
'post': [],
'priority': 2
}
else:
raise ValueError(f"unexpected block subtype: {block.subtype}")

seq = [map_block(b) for b in seq]
start_block = {
'name': 'pre-session',
'block': inst.StareBlock(name="pre-session", az=state.az_now, alt=state.el_now, t0=t0, t1=t0+dt.timedelta(seconds=1)),
'pre': [],
'in': [],
'post': pre_sess, # scheduled after t0
'priority': 3,
'pinned': True # remain unchanged during multi-pass
}
# move to stow position if specified, otherwise keep final position
if len(pos_sess) > 0:
# find an alt, az that is sun-safe for the entire duration of the schedule.
if not self.stages['build_op']['plan_moves']['stow_position']:
az_start = 180
alt_start = 60
# add a buffer to start and end to be safe
t_start = t0 - dt.timedelta(seconds=300)
t_end = t1 + dt.timedelta(seconds=300)
az_stow, alt_stow, _, _ = get_parking(t_start, t_end, alt_start, self.stages['build_op']['plan_moves']['sun_policy'])
logger.info(f"found sun safe stow position at ({az_stow}, {alt_stow})")
else:
az_stow = self.stages['build_op']['plan_moves']['stow_position']['az_stow']
alt_stow = self.stages['build_op']['plan_moves']['stow_position']['el_stow']
else:
az_stow = seq[-1]['block'].az
alt_stow = seq[-1]['block'].alt
end_block = {
'name': 'post-session',
'block': inst.StareBlock(name="post-session", az=az_stow, alt=alt_stow, t0=t1-dt.timedelta(seconds=1), t1=t1),
'pre': pos_sess, # scheduled before t1
'in': [],
'post': [],
'priority': 3,
'pinned': True # remain unchanged during multi-pass
}
seq = [start_block] + seq + [end_block]

ops, state = build_op.apply(seq, t0, t1, state)
if return_state:
return ops, state
return ops
Expand Down Expand Up @@ -791,6 +1056,11 @@ def build_schedule(self, t0: dt.datetime, t1: dt.datetime, state: State = None):

return schedule

def add_cal_target(self, *args, **kwargs):
self.cal_targets.append(make_cal_target(*args, **kwargs))

def add_wiregrid_target(self, el_target, hour_utc=12, az_target=180, duration=15*u.minute, **kwargs):
self.cal_targets.append(WiregridTarget(hour=hour_utc, az_target=az_target, el_target=el_target, duration=duration))

# ------------------------
# utilities
Expand Down
Loading