"""Plan schema: steps with ids, dependencies, optional fallback paths with validation.""" from typing import Any from pydantic import BaseModel, Field, field_validator, model_validator class PlanStep(BaseModel): """ Single step in a plan. Validation: - id and description must be non-empty """ id: str = Field(..., min_length=1, description="Step identifier") description: str = Field(..., min_length=1, description="What to do") dependencies: list[str] = Field(default_factory=list, description="Step ids that must complete first") tool_name: str | None = Field(default=None, description="Optional tool to invoke") tool_args: dict[str, Any] = Field(default_factory=dict, description="Optional tool arguments") metadata: dict[str, Any] = Field(default_factory=dict, description="Extra data") @field_validator("id", "description") @classmethod def validate_non_whitespace(cls, v: str) -> str: """Validate string fields are not just whitespace.""" if not v.strip(): raise ValueError("Field cannot be empty or whitespace") return v class Plan(BaseModel): """ Plan graph: steps and optional fallback paths. Validation: - No duplicate step IDs - All dependency references must be valid step IDs - All fallback path references must be valid step IDs - No circular dependencies """ steps: list[PlanStep] = Field(default_factory=list, description="Ordered steps") fallback_paths: list[list[str]] = Field(default_factory=list, description="Alternative step sequences") metadata: dict[str, Any] = Field(default_factory=dict, description="Plan-level metadata") @model_validator(mode="after") def validate_plan(self) -> "Plan": """Validate the entire plan structure.""" step_ids = {s.id for s in self.steps} # Check for duplicate step IDs if len(step_ids) != len(self.steps): seen = set() duplicates = [] for s in self.steps: if s.id in seen: duplicates.append(s.id) seen.add(s.id) raise ValueError(f"Duplicate step IDs: {duplicates}") # Check all dependency references are valid for step in self.steps: invalid_deps = [d for d in step.dependencies if d not in step_ids] if invalid_deps: raise ValueError( f"Step '{step.id}' has invalid dependencies: {invalid_deps}" ) # Check all fallback path references are valid for i, path in enumerate(self.fallback_paths): invalid_refs = [ref for ref in path if ref not in step_ids] if invalid_refs: raise ValueError( f"Fallback path {i} has invalid step references: {invalid_refs}" ) # Check for circular dependencies cycles = self._find_cycles() if cycles: raise ValueError(f"Circular dependencies detected: {cycles}") return self def _find_cycles(self) -> list[list[str]]: """Find circular dependencies in the plan graph using DFS.""" # Build adjacency list graph: dict[str, list[str]] = {s.id: list(s.dependencies) for s in self.steps} cycles = [] visited = set() rec_stack = set() path = [] def dfs(node: str) -> bool: visited.add(node) rec_stack.add(node) path.append(node) for neighbor in graph.get(node, []): if neighbor not in visited: if dfs(neighbor): return True elif neighbor in rec_stack: # Found cycle cycle_start = path.index(neighbor) cycles.append(path[cycle_start:] + [neighbor]) return True path.pop() rec_stack.remove(node) return False for step_id in graph: if step_id not in visited: dfs(step_id) return cycles def step_ids(self) -> list[str]: """Return step ids in order.""" return [s.id for s in self.steps] def get_step(self, step_id: str) -> PlanStep | None: """Get a step by ID.""" for step in self.steps: if step.id == step_id: return step return None def get_dependencies(self, step_id: str) -> list[PlanStep]: """Get all dependency steps for a given step.""" step = self.get_step(step_id) if not step: return [] return [s for s in self.steps if s.id in step.dependencies] def get_dependents(self, step_id: str) -> list[PlanStep]: """Get all steps that depend on the given step.""" return [s for s in self.steps if step_id in s.dependencies] def topological_order(self) -> list[str]: """ Return step IDs in topological order (dependencies first). Uses Kahn's algorithm. """ # Build in-degree map in_degree = {s.id: len(s.dependencies) for s in self.steps} # Build adjacency list (reverse direction for dependents) dependents: dict[str, list[str]] = {s.id: [] for s in self.steps} for step in self.steps: for dep in step.dependencies: if dep in dependents: dependents[dep].append(step.id) # Start with nodes that have no dependencies queue = [sid for sid, deg in in_degree.items() if deg == 0] result = [] while queue: node = queue.pop(0) result.append(node) for dependent in dependents.get(node, []): in_degree[dependent] -= 1 if in_degree[dependent] == 0: queue.append(dependent) # Add any remaining nodes (would indicate cycles, but we validate above) remaining = [sid for sid in in_degree if sid not in result] result.extend(remaining) return result def to_dict(self) -> dict[str, Any]: """Serialize for message payload / state.""" return { "steps": [s.model_dump() for s in self.steps], "fallback_paths": self.fallback_paths, "metadata": self.metadata, } @classmethod def from_dict(cls, d: dict[str, Any]) -> "Plan": """Deserialize from dict. Steps may be dicts (validated) or PlanStep instances.""" if not isinstance(d, dict): raise TypeError(f"Plan.from_dict expects dict, got {type(d).__name__}") raw_steps = d.get("steps", []) steps: list[PlanStep] = [] for s in raw_steps: if isinstance(s, PlanStep): steps.append(s) elif isinstance(s, dict): steps.append(PlanStep.model_validate(s)) else: raise TypeError(f"Step must be dict or PlanStep, got {type(s).__name__}") return cls( steps=steps, fallback_paths=d.get("fallback_paths", []), metadata=d.get("metadata", {}), )