diff --git a/docs/config-reference.md b/docs/config-reference.md index 8a6687a..b51658d 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -122,6 +122,24 @@ Semantic context stage: This stage builds a lightweight repository index of files, Python symbols, imports, and tests, then writes compact relevant snippets for the current task. It is keyword based with symbol-aware scoring, so it works without a vector database or network dependency. +### `on_status` Stage Routing + +Instead of a single `on_fail` catch-all, use `on_status` to route each review status to a different stage: + +```yaml +- id: review + type: agent_review + agent: reviewer + output: review.md + on_status: + pass: summarize + retry: implement + fail: plan + escalate: human +``` + +`on_status` supports `pass`, `fail`, `retry`, and `escalate` keys. For `pass`, it overrides sequential progression and any agent-supplied `next_stage`. For non-pass statuses, the lookup order is: `on_status[status]` → `on_fail` → `next_stage` (agent output). Every target must be the `id` of a stage defined in the same pipeline (in the example above, a stage with `id: human` must exist); an unknown status key, an empty target, or a target that names no stage is rejected when the config is loaded. `on_status.pass` replaces the removed `on_pass` key. + ## Failure, Retry, and Resource Artifacts Failed command and validation stages write deterministic diagnostics under the task artifact directory: diff --git a/nightshift/config.py b/nightshift/config.py index 15041a7..f3a9c8e 100644 --- a/nightshift/config.py +++ b/nightshift/config.py @@ -61,7 +61,7 @@ class StageConfig: commands: tuple[str, ...] 
= () output: str | None = None on_fail: str | None = None - on_pass: str | None = None + on_status: dict[str, str] | None = None shell: bool = True timeout_seconds: int | None = None working_dir: Path | None = None @@ -393,7 +393,7 @@ def parse_config(raw: dict[str, Any], config_path: Path) -> NightShiftConfig: commands=commands, output=_optional_string(stage_raw.get("output"), f"{stage_context}.output"), on_fail=_optional_string(stage_raw.get("on_fail"), f"{stage_context}.on_fail"), - on_pass=_optional_string(stage_raw.get("on_pass"), f"{stage_context}.on_pass"), + on_status=_parse_on_status(stage_raw, stage_context), shell=_optional_bool(stage_raw.get("shell", True), f"{stage_context}.shell"), timeout_seconds=timeout_seconds, working_dir=Path(working_dir_raw) if working_dir_raw else None, @@ -418,10 +418,13 @@ def parse_config(raw: dict[str, Any], config_path: Path) -> NightShiftConfig: raise ConfigError( f"Config error: stage '{stage.id}' on_fail references unknown stage '{stage.on_fail}'." ) - if stage.on_pass and stage.on_pass not in stage_ids: - raise ConfigError( - f"Config error: stage '{stage.id}' on_pass references unknown stage '{stage.on_pass}'." - ) + if stage.on_status: + for status_key, target in stage.on_status.items(): + if target not in stage_ids: + raise ConfigError( + f"Config error: stage '{stage.id}' on_status.{status_key} " + f"references unknown stage '{target}'." 
+ ) return NightShiftConfig( path=config_path, @@ -635,3 +638,27 @@ def _string_tuple(value: Any, context: str) -> tuple[str, ...]: if not isinstance(value, list) or not all(isinstance(item, str) and item for item in value): raise ConfigError(f"Config error: '{context}' must be a list of non-empty strings.") return tuple(value) + + +VALID_STATUS_KEYS = frozenset({"pass", "fail", "retry", "escalate"}) + + +def _parse_on_status(raw: dict[str, Any], context: str) -> dict[str, str] | None: + on_status_raw = raw.get("on_status") + if on_status_raw is None: + return None + if not isinstance(on_status_raw, dict): + raise ConfigError(f"Config error: {context}.on_status must be a mapping.") + on_status: dict[str, str] = {} + for key, value in on_status_raw.items(): + if key not in VALID_STATUS_KEYS: + raise ConfigError( + f"Config error: {context}.on_status invalid key '{key}'. " + f"Valid keys: {', '.join(sorted(VALID_STATUS_KEYS))}." + ) + if not isinstance(value, str) or not value: + raise ConfigError( + f"Config error: {context}.on_status.{key} must be a non-empty string." + ) + on_status[key] = value + return on_status diff --git a/nightshift/pipeline.py b/nightshift/pipeline.py index 4a4185d..afec69d 100644 --- a/nightshift/pipeline.py +++ b/nightshift/pipeline.py @@ -200,22 +200,41 @@ class PipelineRunner: retry_notes.append(f"Context update from '{stage.id}': {result.context_update}") if result.status == "pass": - pass_target_stage = result.next_stage or stage.on_pass - if stage.type in {"agent_review", "review"} and result.next_stage: + if stage.on_status and "pass" in stage.on_status: + target = stage.on_status["pass"] + if target not in stage_indexes: + final_status = "failed" + final_reason = ( + f"Stage '{stage.id}' on_status.pass references unknown stage '{target}'." 
+ ) + break self.logger.event( - "stage.next_ignored", - "Ignoring next_stage from passing review", + "stage.next", + "Jumping via on_status.pass", run_id=self.artifacts.run_id, task_id=task.id, stage_id=stage.id, - requested_next_stage=result.next_stage, + next_stage=target, ) - pass_target_stage = stage.on_pass - if pass_target_stage: - if pass_target_stage not in stage_indexes: + index = stage_indexes[target] + continue + if stage.type in {"agent_review", "review"}: + if result.next_stage: + self.logger.event( + "stage.next_ignored", + "Ignoring next_stage from passing review", + run_id=self.artifacts.run_id, + task_id=task.id, + stage_id=stage.id, + requested_next_stage=result.next_stage, + ) + index += 1 + continue + if result.next_stage: + if result.next_stage not in stage_indexes: final_status = "failed" final_reason = ( - f"Stage '{stage.id}' requested unknown next stage '{pass_target_stage}'." + f"Stage '{stage.id}' requested unknown next stage '{result.next_stage}'." ) break self.logger.event( @@ -224,14 +243,14 @@ class PipelineRunner: run_id=self.artifacts.run_id, task_id=task.id, stage_id=stage.id, - next_stage=pass_target_stage, + next_stage=result.next_stage, ) - index = stage_indexes[pass_target_stage] + index = stage_indexes[result.next_stage] continue index += 1 continue - target_stage = _failure_target_stage(stage, result) + target_stage = _resolve_retry_target_stage(stage, result) analysis_note = self._write_failure_diagnostics(stage, task, result, retry_count) if analysis_note: retry_notes.append(analysis_note) @@ -1840,14 +1859,10 @@ def _is_malformed_review_result(result: StageResult) -> bool: ) -def _failure_target_stage(stage: StageConfig, result: StageResult) -> str | None: - if stage.type not in {"agent_review", "review"}: - return result.next_stage or stage.on_fail - if _is_malformed_review_result(result): +def _resolve_retry_target_stage(stage: StageConfig, result: StageResult) -> str | None: + if stage.type in {"agent_review", 
"review"} and _is_malformed_review_result(result): return None - if result.next_stage and result.next_stage != stage.id: - return result.next_stage - return stage.on_fail + return (stage.on_status or {}).get(result.status) or stage.on_fail or result.next_stage def _previous_continuity_review_passed(previous_outputs: dict[str, str]) -> bool: diff --git a/nightshift/templates.py b/nightshift/templates.py index d5ed071..f574c49 100644 --- a/nightshift/templates.py +++ b/nightshift/templates.py @@ -68,6 +68,12 @@ pipeline: - id: review type: agent_review agent: reviewer + # on_fail: implement # catch-all for any non-pass status + # on_status: # per-status routing (takes priority over on_fail) + # pass: summarize + # retry: implement + # fail: plan + # escalate: human on_fail: implement output: review.md diff --git a/tests/test_config.py b/tests/test_config.py index 8d868d8..25d7f64 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -55,39 +55,57 @@ class ConfigTests(unittest.TestCase): with self.assertRaisesRegex(ConfigError, "on_fail references unknown stage"): load_config(config_path) - def test_on_pass_must_reference_existing_stage(self) -> None: + def test_on_status_parses_correctly(self) -> None: with tempfile.TemporaryDirectory() as directory: root = Path(directory) init_project(root) config_path = root / "nightshift.yaml" - config_path.write_text( - config_path.read_text(encoding="utf-8").replace( - "on_fail: plan", "on_pass: missing_stage", 1 - ), - encoding="utf-8", - ) - - with self.assertRaisesRegex(ConfigError, "on_pass references unknown stage"): - load_config(config_path) - - def test_on_pass_loads(self) -> None: - with tempfile.TemporaryDirectory() as directory: - root = Path(directory) - init_project(root) - config_path = root / "nightshift.yaml" - config_path.write_text( - config_path.read_text(encoding="utf-8").replace( - " output: plan.md", - " output: plan.md\n on_pass: summarize", - 1, - ), - encoding="utf-8", + text = 
config_path.read_text(encoding="utf-8") + text = text.replace( + " on_fail: implement\n output: review.md", + " output: review.md\n on_status:\n pass: summarize\n retry: implement\n fail: plan", ) + config_path.write_text(text, encoding="utf-8") config = load_config(config_path) - plan_stage = next(stage for stage in config.pipeline.stages if stage.id == "plan") + review_stage = next(s for s in config.pipeline.stages if s.id == "review") - self.assertEqual(plan_stage.on_pass, "summarize") + self.assertEqual(review_stage.on_status, { + "pass": "summarize", + "retry": "implement", + "fail": "plan", + }) + self.assertIsNone(review_stage.on_fail) + + def test_on_status_rejects_invalid_key(self) -> None: + with tempfile.TemporaryDirectory() as directory: + root = Path(directory) + init_project(root) + config_path = root / "nightshift.yaml" + text = config_path.read_text(encoding="utf-8") + text = text.replace( + " on_fail: implement\n output: review.md", + " output: review.md\n on_status:\n wat: broken", + ) + config_path.write_text(text, encoding="utf-8") + + with self.assertRaisesRegex(ConfigError, "on_status invalid key"): + load_config(config_path) + + def test_on_status_references_unknown_stage(self) -> None: + with tempfile.TemporaryDirectory() as directory: + root = Path(directory) + init_project(root) + config_path = root / "nightshift.yaml" + text = config_path.read_text(encoding="utf-8") + text = text.replace( + " on_fail: implement\n output: review.md", + " output: review.md\n on_status:\n fail: missing_stage", + ) + config_path.write_text(text, encoding="utf-8") + + with self.assertRaisesRegex(ConfigError, "on_status.fail references unknown stage"): + load_config(config_path) def test_validate_requires_prompt_files(self) -> None: with tempfile.TemporaryDirectory() as directory: @@ -371,6 +389,39 @@ class ConfigTests(unittest.TestCase): with self.assertRaisesRegex(ConfigError, "non-command stage 'plan'"): load_config(config_path) + def 
test_on_status_empty_key_fails(self) -> None: + with tempfile.TemporaryDirectory() as directory: + root = Path(directory) + init_project(root) + config_path = root / "nightshift.yaml" + text = config_path.read_text(encoding="utf-8") + text = text.replace( + " on_fail: implement\n output: review.md", + " output: review.md\n on_status:\n pass: ", + ) + config_path.write_text(text, encoding="utf-8") + + with self.assertRaisesRegex(ConfigError, "must be a non-empty string"): + load_config(config_path) + + def test_on_fail_fallback_when_on_status_does_not_cover_status(self) -> None: + with tempfile.TemporaryDirectory() as directory: + root = Path(directory) + init_project(root) + config_path = root / "nightshift.yaml" + text = config_path.read_text(encoding="utf-8") + text = text.replace( + " on_fail: implement\n output: review.md", + " output: review.md\n on_status:\n pass: summarize\n on_fail: implement", + ) + config_path.write_text(text, encoding="utf-8") + + config = load_config(config_path) + review_stage = next(s for s in config.pipeline.stages if s.id == "review") + + self.assertEqual(review_stage.on_status, {"pass": "summarize"}) + self.assertEqual(review_stage.on_fail, "implement") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c82dc3e..61dcae3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -105,29 +105,145 @@ class PipelineRunnerTests(unittest.TestCase): ) self.assertIn("Modified Files", (root / ".nightshift" / "runs" / "test-run" / "run-summary.md").read_text(encoding="utf-8")) - def test_on_pass_jumps_to_configured_stage(self) -> None: + def test_on_status_routes_pass_to_target(self) -> None: with tempfile.TemporaryDirectory() as directory: root = Path(directory) _write_common_files(root) stages = ( - StageConfig(id="first", type="agent", agent="planner", output="first.md", on_pass="third"), + StageConfig(id="plan", type="agent", agent="planner", output="plan.md"), 
StageConfig( - id="second", - type="command", - commands=('python -c "print(\'should not run\')"',), - output="second-output.txt", + id="review", + type="agent_review", + agent="reviewer", + on_status={"pass": "summarize"}, + output="review.md", ), - StageConfig(id="third", type="summarize", output="final-notes.md"), + StageConfig(id="implement", type="agent", agent="planner", output="impl.md"), + StageConfig(id="summarize", type="summarize", output="final-notes.md"), ) config = make_config(root, stages) runner = PipelineRunner(config, ArtifactStore(root, ".nightshift", run_id="test-run")) + task = parse_tasks(TASK_MD)[0] - result = runner.run_task(parse_tasks(TASK_MD)[0]) + result = runner.run_task(task) - task_dir = root / ".nightshift" / "runs" / "test-run" / "tasks" / "TASK-001" self.assertEqual(result.status, "complete") - self.assertEqual([item.stage_id for item in result.stage_results], ["first", "third"]) - self.assertFalse((task_dir / "second-output.txt").exists()) + self.assertEqual(result.retry_count, 0) + self.assertEqual( + [r.stage_id for r in result.stage_results], + ["plan", "review", "summarize"], + ) + + def test_on_status_routes_fail_to_target(self) -> None: + with tempfile.TemporaryDirectory() as directory: + root = Path(directory) + _write_common_files(root) + fail_reviewer = 'python -c "print(\'status: fail\\nreason: bad plan\')"' + stages = ( + StageConfig(id="plan", type="agent", agent="planner", output="plan.md"), + StageConfig( + id="review", + type="agent_review", + agent="reviewer", + on_status={"fail": "plan"}, + output="review.md", + ), + StageConfig(id="summarize", type="summarize", output="final-notes.md"), + ) + config = make_config(root, stages) + config.agents["reviewer"] = AgentConfig( + id="reviewer", + backend="command", + command=fail_reviewer, + system_prompt=Path("reviewer.md"), + ) + runner = PipelineRunner(config, ArtifactStore(root, ".nightshift", run_id="test-run")) + task = parse_tasks(TASK_MD)[0] + + result = 
runner.run_task(task) + + self.assertEqual(result.status, "failed") + self.assertEqual(result.retry_count, 2) + self.assertEqual( + [r.stage_id for r in result.stage_results], + ["plan", "review", "plan", "review", "plan", "review"], + ) + + def test_on_status_escalate_routes_to_human_not_on_fail(self) -> None: + with tempfile.TemporaryDirectory() as directory: + root = Path(directory) + _write_common_files(root) + escalate_reviewer = 'python -c "print(\'status: escalate\\nreason: need human\')"' + stages = ( + StageConfig(id="plan", type="agent", agent="planner", output="plan.md"), + StageConfig( + id="review", + type="agent_review", + agent="reviewer", + on_status={ + "retry": "plan", + "escalate": "human", + }, + on_fail="plan", + output="review.md", + ), + StageConfig(id="human", type="summarize", output="human-notes.md"), + StageConfig(id="summarize", type="summarize", output="final-notes.md"), + ) + config = make_config(root, stages) + config.agents["reviewer"] = AgentConfig( + id="reviewer", + backend="command", + command=escalate_reviewer, + system_prompt=Path("reviewer.md"), + ) + runner = PipelineRunner(config, ArtifactStore(root, ".nightshift", run_id="test-run")) + task = parse_tasks(TASK_MD)[0] + + result = runner.run_task(task) + + self.assertEqual(result.status, "complete") + self.assertEqual(result.retry_count, 1) + self.assertEqual( + [r.stage_id for r in result.stage_results], + ["plan", "review", "human", "summarize"], + ) + + def test_on_fail_fallback_when_status_not_in_on_status(self) -> None: + with tempfile.TemporaryDirectory() as directory: + root = Path(directory) + _write_common_files(root) + fail_reviewer = 'python -c "print(\'status: fail\\nreason: bad\')"' + stages = ( + StageConfig(id="plan", type="agent", agent="planner", output="plan.md"), + StageConfig( + id="review", + type="agent_review", + agent="reviewer", + on_status={"retry": "plan"}, + on_fail="implement", + output="review.md", + ), + StageConfig(id="implement", type="agent", 
agent="planner", output="impl.md"), + ) + config = make_config(root, stages) + config.agents["reviewer"] = AgentConfig( + id="reviewer", + backend="command", + command=fail_reviewer, + system_prompt=Path("reviewer.md"), + ) + runner = PipelineRunner(config, ArtifactStore(root, ".nightshift", run_id="test-run")) + task = parse_tasks(TASK_MD)[0] + + result = runner.run_task(task) + + self.assertEqual(result.status, "failed") + self.assertEqual(result.retry_count, 2) + self.assertEqual( + [r.stage_id for r in result.stage_results], + ["plan", "review", "implement", "review", "implement", "review"], + ) def test_task_preflight_fails_when_task_specific_test_file_is_missing(self) -> None: with tempfile.TemporaryDirectory() as directory: