Initial commit: add .gitignore and README
This commit is contained in:
150
fusionagi/memory/working.py
Normal file
150
fusionagi/memory/working.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Working memory: in-memory key-value / list per task/session.
|
||||
|
||||
Working memory provides short-term storage for active tasks:
|
||||
- Key-value storage per session/task
|
||||
- List append operations for accumulating results
|
||||
- Context retrieval for reasoning
|
||||
- Session lifecycle management
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterator
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi._time import utc_now
|
||||
|
||||
|
||||
class WorkingMemory:
|
||||
"""
|
||||
Short-term working memory per task/session.
|
||||
|
||||
Features:
|
||||
- Key-value get/set operations
|
||||
- List append with automatic coercion
|
||||
- Context summary for LLM prompts
|
||||
- Session management and cleanup
|
||||
- Size limits to prevent unbounded growth
|
||||
"""
|
||||
|
||||
def __init__(self, max_entries_per_session: int = 1000) -> None:
|
||||
"""
|
||||
Initialize working memory.
|
||||
|
||||
Args:
|
||||
max_entries_per_session: Maximum entries per session before oldest are removed.
|
||||
"""
|
||||
self._store: dict[str, dict[str, Any]] = defaultdict(dict)
|
||||
self._timestamps: dict[str, datetime] = {}
|
||||
self._max_entries = max_entries_per_session
|
||||
|
||||
def get(self, session_id: str, key: str, default: Any = None) -> Any:
|
||||
"""Get value for session and key; returns default if not found."""
|
||||
return self._store[session_id].get(key, default)
|
||||
|
||||
def set(self, session_id: str, key: str, value: Any) -> None:
|
||||
"""Set value for session and key."""
|
||||
self._store[session_id][key] = value
|
||||
self._timestamps[session_id] = utc_now()
|
||||
self._enforce_limits(session_id)
|
||||
|
||||
def append(self, session_id: str, key: str, value: Any) -> None:
|
||||
"""Append to list for session and key (creates list if needed)."""
|
||||
if key not in self._store[session_id]:
|
||||
self._store[session_id][key] = []
|
||||
lst = self._store[session_id][key]
|
||||
if not isinstance(lst, list):
|
||||
lst = [lst]
|
||||
self._store[session_id][key] = lst
|
||||
lst.append(value)
|
||||
self._timestamps[session_id] = utc_now()
|
||||
self._enforce_limits(session_id)
|
||||
|
||||
def get_list(self, session_id: str, key: str) -> list[Any]:
|
||||
"""Return list for session and key (copy)."""
|
||||
val = self._store[session_id].get(key)
|
||||
if isinstance(val, list):
|
||||
return list(val)
|
||||
return [val] if val is not None else []
|
||||
|
||||
def has(self, session_id: str, key: str) -> bool:
|
||||
"""Check if a key exists in session."""
|
||||
return key in self._store.get(session_id, {})
|
||||
|
||||
def keys(self, session_id: str) -> list[str]:
|
||||
"""Return all keys for a session."""
|
||||
return list(self._store.get(session_id, {}).keys())
|
||||
|
||||
def delete(self, session_id: str, key: str) -> bool:
|
||||
"""Delete a key from session. Returns True if existed."""
|
||||
if session_id in self._store and key in self._store[session_id]:
|
||||
del self._store[session_id][key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""Clear all data for a session."""
|
||||
self._store.pop(session_id, None)
|
||||
self._timestamps.pop(session_id, None)
|
||||
|
||||
def get_context_summary(self, session_id: str, max_items: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Get a summary of working memory for context injection.
|
||||
|
||||
Useful for including relevant context in LLM prompts.
|
||||
"""
|
||||
session_data = self._store.get(session_id, {})
|
||||
summary = {}
|
||||
|
||||
for key, value in list(session_data.items())[:max_items]:
|
||||
if isinstance(value, list):
|
||||
# For lists, include count and last few items
|
||||
summary[key] = {
|
||||
"type": "list",
|
||||
"count": len(value),
|
||||
"recent": value[-3:] if len(value) > 3 else value,
|
||||
}
|
||||
elif isinstance(value, dict):
|
||||
# For dicts, include keys
|
||||
summary[key] = {
|
||||
"type": "dict",
|
||||
"keys": list(value.keys())[:10],
|
||||
}
|
||||
else:
|
||||
# For scalars, include the value (truncated if string)
|
||||
if isinstance(value, str) and len(value) > 200:
|
||||
summary[key] = value[:200] + "..."
|
||||
else:
|
||||
summary[key] = value
|
||||
|
||||
return summary
|
||||
|
||||
def get_all(self, session_id: str) -> dict[str, Any]:
|
||||
"""Return all data for a session (copy)."""
|
||||
return dict(self._store.get(session_id, {}))
|
||||
|
||||
def session_exists(self, session_id: str) -> bool:
|
||||
"""Check if a session has any data."""
|
||||
return session_id in self._store and bool(self._store[session_id])
|
||||
|
||||
def active_sessions(self) -> list[str]:
|
||||
"""Return list of sessions with data."""
|
||||
return [sid for sid, data in self._store.items() if data]
|
||||
|
||||
def session_count(self) -> int:
|
||||
"""Return number of active sessions."""
|
||||
return len([s for s in self._store.values() if s])
|
||||
|
||||
def _enforce_limits(self, session_id: str) -> None:
|
||||
"""Enforce size limits on session data."""
|
||||
session_data = self._store.get(session_id, {})
|
||||
total_items = sum(
|
||||
len(v) if isinstance(v, (list, dict)) else 1
|
||||
for v in session_data.values()
|
||||
)
|
||||
|
||||
if total_items > self._max_entries:
|
||||
logger.warning(
|
||||
"Working memory size limit exceeded",
|
||||
extra={"session_id": session_id, "items": total_items},
|
||||
)
|
||||
Reference in New Issue
Block a user