mirror of
https://github.com/khodges42/nightShift.git
synced 2026-06-14 18:18:36 +00:00
494 lines
16 KiB
Python
494 lines
16 KiB
Python
"""Unified diff extraction and validation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import difflib
|
|
from pathlib import Path
|
|
import re
|
|
import subprocess
|
|
|
|
from .config import SafetyConfig
|
|
from .errors import PipelineError, SafetyError
|
|
from .safety import resolve_inside_root, resolve_project_root, validate_scoped_paths
|
|
|
|
|
|
DEFAULT_MAX_FILES = 20
|
|
DEFAULT_MAX_CHANGED_LINES = 2000
|
|
DEFAULT_FORBIDDEN_PATHS = (".git", ".nightshift", ".env")
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class PatchValidationResult:
|
|
files: tuple[str, ...]
|
|
changed_lines: int
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class PatchApplyResult:
|
|
status: str
|
|
command: str
|
|
exit_code: int
|
|
stdout: str
|
|
stderr: str
|
|
mode: str
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class FileUpdate:
|
|
path: str
|
|
content: str
|
|
|
|
|
|
def extract_unified_diff(text: str) -> str:
|
|
fenced = re.search(r"```(?:diff|patch)?\s*\n(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
|
|
candidate = fenced.group(1) if fenced else text
|
|
lines = candidate.splitlines()
|
|
start = next((index for index, line in enumerate(lines) if line.startswith("diff --git ")), None)
|
|
if start is None:
|
|
start = next((index for index, line in enumerate(lines) if line.startswith("--- ")), None)
|
|
if start is None:
|
|
raise PipelineError("Patch error: no unified diff found.")
|
|
patch = "\n".join(lines[start:]).strip()
|
|
if not patch:
|
|
raise PipelineError("Patch error: unified diff is empty.")
|
|
return patch + "\n"
|
|
|
|
|
|
def normalize_patch_text(text: str) -> str:
|
|
patch = extract_unified_diff(text)
|
|
if "@@" not in patch:
|
|
raise PipelineError("Patch error: unified diff has no hunks.")
|
|
return repair_hunk_counts(patch)
|
|
|
|
|
|
def repair_hunk_counts(patch: str) -> str:
|
|
"""Rewrite unified diff hunk counts from the actual hunk body."""
|
|
|
|
lines = patch.splitlines()
|
|
repaired: list[str] = []
|
|
index = 0
|
|
while index < len(lines):
|
|
line = lines[index]
|
|
if not line.startswith("@@"):
|
|
repaired.append(line)
|
|
index += 1
|
|
continue
|
|
|
|
body: list[str] = []
|
|
body_index = index + 1
|
|
while body_index < len(lines):
|
|
next_line = lines[body_index]
|
|
if next_line.startswith("@@") or next_line.startswith("diff --git "):
|
|
break
|
|
body.append(next_line)
|
|
body_index += 1
|
|
repaired.append(_format_hunk_header(line, body, index + 1))
|
|
repaired.extend(body)
|
|
index = body_index
|
|
return "\n".join(repaired).rstrip() + "\n"
|
|
|
|
|
|
def parse_file_updates(text: str) -> tuple[FileUpdate, ...]:
|
|
"""Parse model-supplied complete file content blocks."""
|
|
|
|
updates: list[FileUpdate] = []
|
|
pattern = re.compile(
|
|
r"```(?:file|path)[:=](?P<path>[^\n`]+)\n(?P<content>.*?)```",
|
|
flags=re.DOTALL | re.IGNORECASE,
|
|
)
|
|
for match in pattern.finditer(text):
|
|
path = match.group("path").strip()
|
|
content = match.group("content")
|
|
if not path:
|
|
continue
|
|
updates.append(FileUpdate(path=path, content=content))
|
|
if not updates:
|
|
raise PipelineError(
|
|
"File writer error: no file blocks found. Expected fenced blocks like ```file:path.py."
|
|
)
|
|
return tuple(updates)
|
|
|
|
|
|
def generate_patch_from_file_updates(
|
|
updates: tuple[FileUpdate, ...],
|
|
project_root: str | Path,
|
|
safety: SafetyConfig,
|
|
forbidden_paths: tuple[str, ...] = DEFAULT_FORBIDDEN_PATHS,
|
|
) -> str:
|
|
root = resolve_project_root(project_root)
|
|
scoped_roots = validate_scoped_paths(root, safety.scoped_paths or (".",))
|
|
patch_parts: list[str] = []
|
|
seen: set[str] = set()
|
|
for update in updates:
|
|
normalized_path = _normalize_update_path(update.path)
|
|
if normalized_path in seen:
|
|
raise PipelineError(f"File writer error: duplicate file block `{normalized_path}`.")
|
|
seen.add(normalized_path)
|
|
_validate_patch_path(normalized_path, root, scoped_roots, forbidden_paths)
|
|
file_path = resolve_inside_root(root, normalized_path, f"file update '{normalized_path}'")
|
|
old_text = file_path.read_text(encoding="utf-8", errors="replace") if file_path.exists() else ""
|
|
if old_text == update.content:
|
|
continue
|
|
patch_parts.extend(_diff_for_file(normalized_path, old_text, update.content, file_path.exists()))
|
|
if not patch_parts:
|
|
raise PipelineError("File writer error: generated patch has no changes.")
|
|
return "\n".join(patch_parts).rstrip() + "\n"
|
|
|
|
|
|
def validate_patch(
|
|
patch: str,
|
|
project_root: str | Path,
|
|
safety: SafetyConfig,
|
|
max_files: int = DEFAULT_MAX_FILES,
|
|
max_changed_lines: int = DEFAULT_MAX_CHANGED_LINES,
|
|
forbidden_paths: tuple[str, ...] = DEFAULT_FORBIDDEN_PATHS,
|
|
) -> PatchValidationResult:
|
|
root = resolve_project_root(project_root)
|
|
scoped_roots = validate_scoped_paths(root, safety.scoped_paths or (".",))
|
|
files = _patch_files(patch)
|
|
if not files:
|
|
raise PipelineError("Patch validation failed: no changed files found.")
|
|
if len(files) > max_files:
|
|
raise PipelineError(f"Patch validation failed: touches {len(files)} files, max is {max_files}.")
|
|
|
|
changed_lines = _changed_line_count(patch)
|
|
if changed_lines <= 0:
|
|
raise PipelineError("Patch validation failed: patch has no changed lines.")
|
|
if changed_lines > max_changed_lines:
|
|
raise PipelineError(
|
|
f"Patch validation failed: changes {changed_lines} lines, max is {max_changed_lines}."
|
|
)
|
|
|
|
for path_text in files:
|
|
_validate_patch_path(path_text, root, scoped_roots, forbidden_paths)
|
|
_validate_hunk_lines(patch)
|
|
_validate_hunk_counts(patch)
|
|
_validate_file_states(patch, root)
|
|
return PatchValidationResult(files=tuple(sorted(files)), changed_lines=changed_lines)
|
|
|
|
|
|
def format_validation_result(result: PatchValidationResult) -> str:
|
|
return "\n".join(
|
|
[
|
|
"# Patch Validation",
|
|
"",
|
|
"Status: pass",
|
|
f"Changed files: {len(result.files)}",
|
|
f"Changed lines: {result.changed_lines}",
|
|
"",
|
|
"## Files",
|
|
"",
|
|
*[f"- `{path}`" for path in result.files],
|
|
"",
|
|
]
|
|
)
|
|
|
|
|
|
def apply_patch_with_git(patch_path: Path, project_root: str | Path, mode: str = "dry_run") -> PatchApplyResult:
|
|
root = resolve_project_root(project_root)
|
|
command = ["git", "apply", "--ignore-whitespace", "--check", str(patch_path)]
|
|
if mode == "apply":
|
|
command = ["git", "apply", "--ignore-whitespace", str(patch_path)]
|
|
completed = subprocess.run(
|
|
command,
|
|
cwd=root,
|
|
capture_output=True,
|
|
text=True,
|
|
encoding="utf-8",
|
|
errors="replace",
|
|
)
|
|
return PatchApplyResult(
|
|
status="pass" if completed.returncode == 0 else "fail",
|
|
command=" ".join(command),
|
|
exit_code=completed.returncode,
|
|
stdout=completed.stdout or "",
|
|
stderr=completed.stderr or "",
|
|
mode=mode,
|
|
)
|
|
|
|
|
|
def format_patch_apply_result(result: PatchApplyResult, patch_path: str) -> str:
|
|
return "\n".join(
|
|
[
|
|
"# Patch Apply",
|
|
"",
|
|
f"Status: {result.status}",
|
|
f"Mode: {result.mode}",
|
|
f"Patch: `{patch_path}`",
|
|
f"Command: `{result.command}`",
|
|
f"Exit code: {result.exit_code}",
|
|
"",
|
|
"## stdout",
|
|
"",
|
|
"```text",
|
|
result.stdout.rstrip(),
|
|
"```",
|
|
"",
|
|
"## stderr",
|
|
"",
|
|
"```text",
|
|
result.stderr.rstrip(),
|
|
"```",
|
|
"",
|
|
]
|
|
)
|
|
|
|
|
|
def _patch_files(patch: str) -> set[str]:
|
|
files: set[str] = set()
|
|
saw_hunk = False
|
|
for line in patch.splitlines():
|
|
if line.startswith("@@"):
|
|
saw_hunk = True
|
|
if line.startswith("diff --git "):
|
|
parts = line.split()
|
|
if len(parts) >= 4:
|
|
files.add(_strip_prefix(parts[3]))
|
|
elif line.startswith("+++ "):
|
|
target = line[4:].strip()
|
|
if target != "/dev/null":
|
|
files.add(_strip_prefix(target))
|
|
elif line.startswith("--- "):
|
|
source = line[4:].strip()
|
|
if source != "/dev/null":
|
|
files.add(_strip_prefix(source))
|
|
if not saw_hunk:
|
|
raise PipelineError("Patch validation failed: unified diff has no hunk headers.")
|
|
return {path for path in files if path}
|
|
|
|
|
|
def _validate_hunk_lines(patch: str) -> None:
|
|
in_hunk = False
|
|
for line_number, line in enumerate(patch.splitlines(), start=1):
|
|
if line.startswith("diff --git "):
|
|
in_hunk = False
|
|
continue
|
|
if line.startswith("@@"):
|
|
in_hunk = True
|
|
continue
|
|
if not in_hunk:
|
|
continue
|
|
if line.startswith(("+", "-", " ", "\\")):
|
|
continue
|
|
raise PipelineError(
|
|
"Patch validation failed: malformed hunk line "
|
|
f"{line_number}; expected a leading space, '+', '-', or backslash."
|
|
)
|
|
|
|
|
|
def _validate_hunk_counts(patch: str) -> None:
|
|
current: dict[str, int] | None = None
|
|
|
|
def flush(line_number: int) -> None:
|
|
if current is None:
|
|
return
|
|
old_expected = current["old_expected"]
|
|
new_expected = current["new_expected"]
|
|
old_actual = current["old_actual"]
|
|
new_actual = current["new_actual"]
|
|
hunk_line = current["line_number"]
|
|
if old_actual != old_expected:
|
|
raise PipelineError(
|
|
"Patch validation failed: hunk starting at line "
|
|
f"{hunk_line} old line count expected {old_expected}, got {old_actual} "
|
|
f"before line {line_number}."
|
|
)
|
|
if new_actual != new_expected:
|
|
raise PipelineError(
|
|
"Patch validation failed: hunk starting at line "
|
|
f"{hunk_line} new line count expected {new_expected}, got {new_actual} "
|
|
f"before line {line_number}."
|
|
)
|
|
|
|
for line_number, line in enumerate(patch.splitlines(), start=1):
|
|
if line.startswith("@@"):
|
|
flush(line_number)
|
|
current = _parse_hunk_header(line, line_number)
|
|
continue
|
|
if current is None:
|
|
continue
|
|
if line.startswith("diff --git "):
|
|
flush(line_number)
|
|
current = None
|
|
continue
|
|
if line.startswith("\\"):
|
|
continue
|
|
if line.startswith(" "):
|
|
current["old_actual"] += 1
|
|
current["new_actual"] += 1
|
|
elif line.startswith("-"):
|
|
current["old_actual"] += 1
|
|
elif line.startswith("+"):
|
|
current["new_actual"] += 1
|
|
flush(len(patch.splitlines()) + 1)
|
|
|
|
|
|
def _parse_hunk_header(line: str, line_number: int) -> dict[str, int]:
|
|
match = re.match(
|
|
r"^@@ -(?P<old_start>\d+)(?:,(?P<old_count>\d+))? "
|
|
r"\+(?P<new_start>\d+)(?:,(?P<new_count>\d+))? @@",
|
|
line,
|
|
)
|
|
if not match:
|
|
raise PipelineError(
|
|
f"Patch validation failed: malformed hunk header at line {line_number}."
|
|
)
|
|
old_count = int(match.group("old_count") or "1")
|
|
new_count = int(match.group("new_count") or "1")
|
|
return {
|
|
"line_number": line_number,
|
|
"old_expected": old_count,
|
|
"new_expected": new_count,
|
|
"old_actual": 0,
|
|
"new_actual": 0,
|
|
}
|
|
|
|
|
|
def _format_hunk_header(line: str, body: list[str], line_number: int) -> str:
|
|
match = re.match(
|
|
r"^@@ -(?P<old_start>\d+)(?:,(?P<old_count>\d+))? "
|
|
r"\+(?P<new_start>\d+)(?:,(?P<new_count>\d+))? @@(?P<section>.*)$",
|
|
line,
|
|
)
|
|
if not match:
|
|
raise PipelineError(
|
|
f"Patch validation failed: malformed hunk header at line {line_number}."
|
|
)
|
|
old_count = 0
|
|
new_count = 0
|
|
for body_line in body:
|
|
if body_line.startswith("\\"):
|
|
continue
|
|
if body_line.startswith(" "):
|
|
old_count += 1
|
|
new_count += 1
|
|
elif body_line.startswith("-"):
|
|
old_count += 1
|
|
elif body_line.startswith("+"):
|
|
new_count += 1
|
|
return (
|
|
f"@@ -{match.group('old_start')}{_format_count(old_count)} "
|
|
f"+{match.group('new_start')}{_format_count(new_count)} @@"
|
|
f"{match.group('section')}"
|
|
)
|
|
|
|
|
|
def _format_count(count: int) -> str:
|
|
return "" if count == 1 else f",{count}"
|
|
|
|
|
|
def _validate_file_states(patch: str, root: Path) -> None:
|
|
current_path: str | None = None
|
|
current_is_new = False
|
|
current_is_deleted = False
|
|
|
|
def flush() -> None:
|
|
if not current_path:
|
|
return
|
|
target = root / current_path
|
|
if current_is_new and target.exists():
|
|
raise PipelineError(
|
|
f"Patch validation failed: patch creates existing file `{current_path}`."
|
|
)
|
|
if current_is_deleted and not target.exists():
|
|
raise PipelineError(
|
|
f"Patch validation failed: patch deletes missing file `{current_path}`."
|
|
)
|
|
|
|
for line in patch.splitlines():
|
|
if line.startswith("diff --git "):
|
|
flush()
|
|
parts = line.split()
|
|
current_path = _strip_prefix(parts[3]) if len(parts) >= 4 else None
|
|
current_is_new = False
|
|
current_is_deleted = False
|
|
elif line.startswith("new file mode "):
|
|
current_is_new = True
|
|
elif line.startswith("deleted file mode "):
|
|
current_is_deleted = True
|
|
flush()
|
|
|
|
|
|
def _changed_line_count(patch: str) -> int:
|
|
count = 0
|
|
in_hunk = False
|
|
for line in patch.splitlines():
|
|
if line.startswith("diff --git "):
|
|
in_hunk = False
|
|
continue
|
|
if line.startswith("@@"):
|
|
in_hunk = True
|
|
continue
|
|
if not in_hunk or line.startswith("\\"):
|
|
continue
|
|
if line.startswith(("+", "-")):
|
|
count += 1
|
|
return count
|
|
|
|
|
|
def _validate_patch_path(
|
|
path_text: str,
|
|
root: Path,
|
|
scoped_roots: tuple[Path, ...],
|
|
forbidden_paths: tuple[str, ...],
|
|
) -> None:
|
|
path = Path(path_text)
|
|
if path.is_absolute() or ".." in path.parts:
|
|
raise PipelineError(f"Patch validation failed: unsafe path `{path_text}`.")
|
|
normalized = path.as_posix()
|
|
for forbidden in forbidden_paths:
|
|
forbidden_path = forbidden.strip("/\\")
|
|
if normalized == forbidden_path or normalized.startswith(forbidden_path + "/"):
|
|
raise PipelineError(f"Patch validation failed: forbidden path `{path_text}`.")
|
|
try:
|
|
resolved = resolve_inside_root(root, path, f"patch path '{path_text}'")
|
|
except SafetyError as exc:
|
|
raise PipelineError(f"Patch validation failed: {exc}") from exc
|
|
for scoped_root in scoped_roots:
|
|
try:
|
|
resolved.relative_to(scoped_root)
|
|
return
|
|
except ValueError:
|
|
continue
|
|
scopes = ", ".join(item.relative_to(root).as_posix() for item in scoped_roots)
|
|
raise PipelineError(
|
|
f"Patch validation failed: path `{path_text}` is outside scoped paths: {scopes}."
|
|
)
|
|
|
|
|
|
def _normalize_update_path(path_text: str) -> str:
|
|
normalized = path_text.replace("\\", "/").strip()
|
|
if normalized.startswith(("a/", "b/")):
|
|
normalized = normalized[2:]
|
|
return normalized
|
|
|
|
|
|
def _diff_for_file(path: str, old_text: str, new_text: str, exists: bool) -> list[str]:
|
|
old_lines = old_text.splitlines()
|
|
new_lines = new_text.splitlines()
|
|
from_file = f"a/{path}" if exists else "/dev/null"
|
|
to_file = f"b/{path}"
|
|
diff_lines = list(
|
|
difflib.unified_diff(
|
|
old_lines,
|
|
new_lines,
|
|
fromfile=from_file,
|
|
tofile=to_file,
|
|
lineterm="",
|
|
)
|
|
)
|
|
if not diff_lines:
|
|
return []
|
|
header = [f"diff --git a/{path} b/{path}"]
|
|
if not exists:
|
|
header.append("new file mode 100644")
|
|
return [*header, *diff_lines]
|
|
|
|
|
|
def _strip_prefix(path_text: str) -> str:
|
|
path = path_text.strip()
|
|
if path.startswith(("a/", "b/")):
|
|
return path[2:]
|
|
return path
|