diff --git a/docs/concepts/custom_memory_storage.mdx b/docs/concepts/custom_memory_storage.mdx new file mode 100644 index 0000000000..f483d02fff --- /dev/null +++ b/docs/concepts/custom_memory_storage.mdx @@ -0,0 +1,151 @@ +# Custom Memory Storage + +CrewAI supports custom memory storage implementations for different memory types. You can provide your own storage implementation by extending the `Storage` interface and passing it to the memory instances or through the `memory_config` parameter. + +## Implementing a Custom Storage + +To create a custom storage implementation, you need to extend the `Storage` interface and implement the required methods: + +```python +from typing import Any, Dict, List +from crewai.memory.storage.interface import Storage + +class CustomStorage(Storage): + """Custom storage implementation.""" + + def __init__(self): + # Initialize your storage backend + self.data = [] + + def save(self, value: Any, metadata: Dict[str, Any]) -> None: + """Save a value with metadata to the storage.""" + # Implement your save logic + self.data.append({"value": value, "metadata": metadata}) + + def search( + self, query: str, limit: int = 3, score_threshold: float = 0.35 + ) -> List[Any]: + """Search for values in the storage.""" + # Implement your search logic + return [{"context": item["value"], "metadata": item["metadata"]} for item in self.data] + + def reset(self) -> None: + """Reset the storage.""" + # Implement your reset logic + self.data = [] +``` + +## Using Custom Storage + +There are two ways to provide custom storage implementations to CrewAI: + +### 1. Pass Custom Storage to Memory Instances + +You can create memory instances with custom storage and pass them to the Crew: + +```python +from crewai import Crew, Agent +from crewai.memory.short_term.short_term_memory import ShortTermMemory +from crewai.memory.long_term.long_term_memory import LongTermMemory +from crewai.memory.entity.entity_memory import EntityMemory +from crewai.memory.user.user_memory import UserMemory + +# Create custom storage instances +short_term_storage = CustomStorage() +long_term_storage = CustomStorage() +entity_storage = CustomStorage() +user_storage = CustomStorage() + +# Create memory instances with custom storage +short_term_memory = ShortTermMemory(storage=short_term_storage) +long_term_memory = LongTermMemory(storage=long_term_storage) +entity_memory = EntityMemory(storage=entity_storage) +user_memory = UserMemory(storage=user_storage) + +# Create a crew with custom memory instances +crew = Crew( + agents=[Agent(role="researcher", goal="research", backstory="I am a researcher")], + memory=True, + short_term_memory=short_term_memory, + long_term_memory=long_term_memory, + entity_memory=entity_memory, + memory_config={"user_memory": user_memory}, +) +``` + +### 2. Pass Custom Storage through Memory Config + +You can also provide custom storage implementations through the `memory_config` parameter: + +```python +from crewai import Crew, Agent + +# Create a crew with custom storage in memory_config +crew = Crew( + agents=[Agent(role="researcher", goal="research", backstory="I am a researcher")], + memory=True, + memory_config={ + "storage": { + "short_term": CustomStorage(), + "long_term": CustomStorage(), + "entity": CustomStorage(), + "user": CustomStorage(), + } + }, +) +``` + +## Example: Redis Storage + +Here's an example of a custom storage implementation using Redis: + +```python +import json +import redis +from typing import Any, Dict, List +from crewai.memory.storage.interface import Storage + +class RedisStorage(Storage): + """Redis-based storage implementation.""" + + def __init__(self, redis_url="redis://localhost:6379/0", prefix="crewai"): + self.redis = redis.from_url(redis_url) + self.prefix = prefix + + def save(self, value: Any, metadata: Dict[str, Any]) -> None: + """Save a value with metadata to Redis.""" + key = f"{self.prefix}:{len(self.redis.keys(f'{self.prefix}:*'))}" + data = {"value": value, "metadata": metadata} + self.redis.set(key, json.dumps(data)) + + def search( + self, query: str, limit: int = 3, score_threshold: float = 0.35 + ) -> List[Any]: + """Search for values in Redis.""" + # This is a simple implementation that returns all values + # In a real implementation, you would use Redis search capabilities + results = [] + for key in self.redis.keys(f"{self.prefix}:*"): + data = json.loads(self.redis.get(key)) + results.append({"context": data["value"], "metadata": data["metadata"]}) + if len(results) >= limit: + break + return results + + def reset(self) -> None: + """Reset the Redis storage.""" + for key in self.redis.keys(f"{self.prefix}:*"): + self.redis.delete(key) +``` + +## Benefits of Custom Storage + +Using custom storage implementations allows you to: + +1. Store memory data in external databases or services +2. Implement custom search algorithms +3. Share memory between different crews or applications +4. Persist memory across application restarts +5. Implement custom memory retention policies + +By extending the `Storage` interface, you can integrate CrewAI with any storage backend that suits your needs. diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 9cecfed3a2..cd31d38291 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -262,8 +262,19 @@ def set_private_attrs(self) -> "Crew": def create_crew_memory(self) -> "Crew": """Set private attributes.""" if self.memory: + from crewai.memory.storage.rag_storage import RAGStorage + + # Create default storage instances for each memory type if needed + long_term_storage = RAGStorage(type="long_term", crew=self, embedder_config=self.embedder) + short_term_storage = RAGStorage(type="short_term", crew=self, embedder_config=self.embedder) + entity_storage = RAGStorage(type="entity", crew=self, embedder_config=self.embedder) + self._long_term_memory = ( - self.long_term_memory if self.long_term_memory else LongTermMemory() + self.long_term_memory if self.long_term_memory else LongTermMemory( + crew=self, + embedder_config=self.embedder, + storage=long_term_storage + ) ) self._short_term_memory = ( self.short_term_memory @@ -271,12 +282,17 @@ def create_crew_memory(self) -> "Crew": else ShortTermMemory( crew=self, embedder_config=self.embedder, + storage=short_term_storage ) ) self._entity_memory = ( self.entity_memory if self.entity_memory - else EntityMemory(crew=self, embedder_config=self.embedder) + else EntityMemory( + crew=self, + embedder_config=self.embedder, + storage=entity_storage + ) ) if ( self.memory_config and "user_memory" in self.memory_config diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index cdb9cf836d..7ec514d5b3 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -47,7 +47,7 @@ def _fetch_stm_context(self, query) -> str: stm_results = self.stm.search(query) formatted_results = "\n".join( [ - f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}" + f"- {result.get('memory', result.get('context', ''))}" for result in stm_results ] ) @@ -58,7 +58,7 @@ def _fetch_ltm_context(self, task) -> Optional[str]: Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, formatted as bullet points. """ - ltm_results = self.ltm.search(task, latest_n=2) + ltm_results = self.ltm.search(query=task, limit=2) if not ltm_results: return None @@ -80,9 +80,9 @@ def _fetch_entity_context(self, query) -> str: em_results = self.em.search(query) formatted_results = "\n".join( [ - f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}" + f"- {result.get('memory', result.get('context', ''))}" for result in em_results - ] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" + ] ) return f"Entities:\n{formatted_results}" if em_results else "" @@ -99,6 +99,6 @@ def _fetch_user_context(self, query: str) -> str: return "" formatted_memories = "\n".join( - f"- {result['memory']}" for result in user_memories + f"- {result.get('memory', result.get('context', ''))}" for result in user_memories ) return f"User memories/preferences:\n{formatted_memories}" diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 264b641032..28f9bd4ba9 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional from pydantic import PrivateAttr @@ -17,47 +17,71 @@ class EntityMemory(Memory): _memory_provider: Optional[str] = PrivateAttr() def __init__(self, crew=None, embedder_config=None, storage=None, path=None): + memory_provider = None + memory_config = None + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: - memory_provider = crew.memory_config.get("provider") - else: - memory_provider = None - - if memory_provider == "mem0": + memory_config = crew.memory_config + memory_provider = memory_config.get("provider") + + # If no storage is provided, try to create one + if storage is None: try: - from crewai.memory.storage.mem0_storage import Mem0Storage - except ImportError: - raise ImportError( - "Mem0 is not installed. Please install it with `pip install mem0ai`." + # Try to select storage using helper method + storage = self._select_storage( + storage=storage, + memory_config=memory_config, + storage_type="entity", + crew=crew, + path=path, + default_storage_factory=lambda path, crew: RAGStorage( + type="entities", + allow_reset=True, + crew=crew, + embedder_config=embedder_config, + path=path, + ) ) - storage = Mem0Storage(type="entities", crew=crew) - else: - storage = ( - storage - if storage - else RAGStorage( + except ValueError: + # Fallback to default storage + storage = RAGStorage( type="entities", allow_reset=True, - embedder_config=embedder_config, crew=crew, + embedder_config=embedder_config, path=path, ) - ) - - super().__init__(storage=storage) - self._memory_provider = memory_provider + + # Initialize with parameters + super().__init__( + storage=storage, + embedder_config=embedder_config, + memory_provider=memory_provider + ) + - def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" - """Saves an entity item into the SQLite storage.""" - if self._memory_provider == "mem0": - data = f""" - Remember details about the following entity: - Name: {item.name} - Type: {item.type} - Entity Description: {item.description} - """ + def save( + self, + value: Any, + metadata: Optional[Dict[str, Any]] = None, + agent: Optional[str] = None, + ) -> None: + """Saves an entity item or value into the storage.""" + if isinstance(value, EntityMemoryItem): + item = value + if self.memory_provider == "mem0": + data = f""" + Remember details about the following entity: + Name: {item.name} + Type: {item.type} + Entity Description: {item.description} + """ + else: + data = f"{item.name}({item.type}): {item.description}" + super().save(data, item.metadata) else: - data = f"{item.name}({item.type}): {item.description}" - super().save(data, item.metadata) + # Handle regular value and metadata + super().save(value, metadata, agent) def reset(self) -> None: try: diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index 94aac3a977..bfc16d32bc 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.memory.memory import Memory @@ -14,23 +14,77 @@ class LongTermMemory(Memory): LongTermMemoryItem instances. """ - def __init__(self, storage=None, path=None): - if not storage: - storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() - super().__init__(storage=storage) - - def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" - metadata = item.metadata - metadata.update({"agent": item.agent, "expected_output": item.expected_output}) - self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage" - task_description=item.task, - score=metadata["quality"], - metadata=metadata, - datetime=item.datetime, + def __init__(self, crew=None, embedder_config=None, storage=None, path=None): + memory_provider = None + memory_config = None + + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: + memory_config = crew.memory_config + memory_provider = memory_config.get("provider") + + # Initialize with basic parameters + super().__init__( + storage=storage, + embedder_config=embedder_config, + memory_provider=memory_provider ) + + try: + # Try to select storage using helper method + self.storage = self._select_storage( + storage=storage, + memory_config=memory_config, + storage_type="long_term", + crew=crew, + path=path, + default_storage_factory=lambda path, crew: LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() + ) + except ValueError: + # Fallback to default storage + self.storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() + + def save( + self, + value: Any, + metadata: Optional[Dict[str, Any]] = None, + agent: Optional[str] = None, + ) -> None: + """Saves a value into the memory.""" + if isinstance(value, LongTermMemoryItem): + item = value + item_metadata = item.metadata or {} + item_metadata.update({"agent": item.agent, "expected_output": item.expected_output}) + + # Handle special storage types like Mem0Storage + if hasattr(self.storage, "save") and callable(getattr(self.storage, "save")) and hasattr(self.storage.save, "__code__") and "task_description" in self.storage.save.__code__.co_varnames: + self.storage.save( + task_description=item.task, + score=item_metadata.get("quality", 0), + metadata=item_metadata, + datetime=item.datetime, + ) + else: + # Use standard storage interface + self.storage.save(item.task, item_metadata) + else: + # Handle regular value and metadata + super().save(value, metadata, agent) - def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" - return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load" + def search( + self, + query: str, + limit: int = 3, + score_threshold: float = 0.35, + ) -> List[Any]: + """Search for values in the memory.""" + # Try to use the standard storage interface first + if hasattr(self.storage, "search") and callable(getattr(self.storage, "search")): + return self.storage.search(query=query, limit=limit, score_threshold=score_threshold) + # Fall back to load method for backward compatibility + elif hasattr(self.storage, "load") and callable(getattr(self.storage, "load")): + return self.storage.load(query, limit) + else: + raise AttributeError("Storage does not implement search or load method") def reset(self) -> None: self.storage.reset() diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 9a362a5125..bc9f273ae9 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -1,20 +1,62 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, cast -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field +from crewai.memory.storage.interface import SearchResult, Storage -class Memory(BaseModel): +T = TypeVar('T', bound=Storage) + +class Memory(BaseModel, Generic[T]): """ Base class for memory, now supporting agent tags and generic metadata. """ + + model_config = ConfigDict(arbitrary_types_allowed=True) embedder_config: Optional[Dict[str, Any]] = None + storage: T + memory_provider: Optional[str] = Field(default=None, exclude=True) - storage: Any - - def __init__(self, storage: Any, **data: Any): + def __init__(self, storage: T, **data: Any): super().__init__(storage=storage, **data) + def _select_storage( + self, + storage: Optional[T] = None, + memory_config: Optional[Dict[str, Any]] = None, + storage_type: str = "", + crew=None, + path: Optional[str] = None, + default_storage_factory: Optional[Callable] = None, + ) -> T: + """Helper method to select the appropriate storage based on configuration""" + # Use the provided storage if available + if storage: + return storage + + # Use storage from memory_config if available + if memory_config and "storage" in memory_config: + storage_config = memory_config.get("storage", {}) + if storage_type in storage_config and storage_config[storage_type]: + return cast(T, storage_config[storage_type]) + + # Use Mem0Storage if specified in memory_config + if memory_config and memory_config.get("provider") == "mem0": + try: + from crewai.memory.storage.mem0_storage import Mem0Storage + return cast(T, Mem0Storage(type=storage_type, crew=crew)) + except ImportError: + raise ImportError( + "Mem0 is not installed. Please install it with `pip install mem0ai`." + ) + + # Use default storage if provided + if default_storage_factory: + return cast(T, default_storage_factory(path=path, crew=crew)) + + # Fallback to empty storage + raise ValueError(f"No storage available for {storage_type}") + def save( self, value: Any, @@ -25,14 +67,19 @@ def save( if agent: metadata["agent"] = agent - self.storage.save(value, metadata) + if self.storage: + self.storage.save(value, metadata) + else: + raise ValueError("Storage is not initialized") def search( self, query: str, limit: int = 3, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> List[SearchResult]: + if not self.storage: + raise ValueError("Storage is not initialized") return self.storage.search( query=query, limit=limit, score_threshold=score_threshold ) diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index b7581f4002..460ee9e29c 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -19,32 +19,43 @@ class ShortTermMemory(Memory): _memory_provider: Optional[str] = PrivateAttr() def __init__(self, crew=None, embedder_config=None, storage=None, path=None): + memory_provider = None + memory_config = None + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: - memory_provider = crew.memory_config.get("provider") - else: - memory_provider = None - - if memory_provider == "mem0": - try: - from crewai.memory.storage.mem0_storage import Mem0Storage - except ImportError: - raise ImportError( - "Mem0 is not installed. Please install it with `pip install mem0ai`." - ) - storage = Mem0Storage(type="short_term", crew=crew) - else: - storage = ( - storage - if storage - else RAGStorage( + memory_config = crew.memory_config + memory_provider = memory_config.get("provider") + + # Initialize with basic parameters + super().__init__( + storage=storage, + embedder_config=embedder_config, + memory_provider=memory_provider + ) + + try: + # Try to select storage using helper method + self.storage = self._select_storage( + storage=storage, + memory_config=memory_config, + storage_type="short_term", + crew=crew, + path=path, + default_storage_factory=lambda path, crew: RAGStorage( type="short_term", - embedder_config=embedder_config, crew=crew, + embedder_config=embedder_config, path=path, ) ) - super().__init__(storage=storage) - self._memory_provider = memory_provider + except ValueError: + # Fallback to default storage + self.storage = RAGStorage( + type="short_term", + crew=crew, + embedder_config=embedder_config, + path=path, + ) def save( self, @@ -53,7 +64,7 @@ def save( agent: Optional[str] = None, ) -> None: item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent) - if self._memory_provider == "mem0": + if self.memory_provider == "mem0": item.data = f"Remember the following insights from Agent run: {item.data}" super().save(value=item.data, metadata=item.metadata, agent=item.agent) diff --git a/src/crewai/memory/storage/base_rag_storage.py b/src/crewai/memory/storage/base_rag_storage.py index 4ab9acb991..1c5772e3d8 100644 --- a/src/crewai/memory/storage/base_rag_storage.py +++ b/src/crewai/memory/storage/base_rag_storage.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional +from crewai.memory.storage.interface import SearchResult, Storage -class BaseRAGStorage(ABC): + +class BaseRAGStorage(Storage[Any], ABC): """ Base class for RAG-based Storage implementations. """ @@ -44,9 +46,8 @@ def search( self, query: str, limit: int = 3, - filter: Optional[dict] = None, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> List[SearchResult]: """Search for entries in the storage.""" pass diff --git a/src/crewai/memory/storage/interface.py b/src/crewai/memory/storage/interface.py index 8bec9a14f2..15d92f1a14 100644 --- a/src/crewai/memory/storage/interface.py +++ b/src/crewai/memory/storage/interface.py @@ -1,16 +1,39 @@ -from typing import Any, Dict, List +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Dict, Generic, List, Protocol, TypeVar, TypedDict, runtime_checkable +from pydantic import BaseModel, ConfigDict -class Storage: +class SearchResult(TypedDict, total=False): + """Type definition for search results""" + context: str + metadata: Dict[str, Any] + score: float + memory: str # For Mem0Storage compatibility + +T = TypeVar('T') + +@runtime_checkable +class StorageProtocol(Protocol): + """Protocol defining the storage interface""" + def save(self, value: Any, metadata: Dict[str, Any]) -> None: ... + def search(self, query: str, limit: int, score_threshold: float) -> List[Any]: ... + def reset(self) -> None: ... + +class Storage(ABC, Generic[T]): """Abstract base class defining the storage interface""" + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + @abstractmethod def save(self, value: Any, metadata: Dict[str, Any]) -> None: pass + @abstractmethod def search( self, query: str, limit: int, score_threshold: float - ) -> Dict[str, Any] | List[Any]: - return {} + ) -> List[SearchResult]: + pass + @abstractmethod def reset(self) -> None: pass diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index be889afff2..da3ee65d31 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -111,3 +111,9 @@ def _get_agent_name(self): agents = [self._sanitize_role(agent.role) for agent in agents] agents = "_".join(agents) return agents + + def reset(self) -> None: + """Reset the storage by clearing all memories.""" + # Mem0 doesn't have a direct reset method, but we can implement + # this in the future if needed. For now, we'll just pass. + pass diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index fd4c77838c..aebb3ebf56 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -9,6 +9,7 @@ from chromadb.api import ClientAPI from crewai.memory.storage.base_rag_storage import BaseRAGStorage +from crewai.memory.storage.interface import SearchResult from crewai.utilities import EmbeddingConfigurator from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.paths import db_storage_path @@ -37,7 +38,7 @@ class RAGStorage(BaseRAGStorage): search efficiency. """ - app: ClientAPI | None = None + app: Optional[ClientAPI] = None def __init__( self, type, allow_reset=True, embedder_config=None, crew=None, path=None @@ -112,9 +113,8 @@ def search( self, query: str, limit: int = 3, - filter: Optional[dict] = None, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> List[SearchResult]: if not hasattr(self, "app"): self._initialize_app() @@ -124,8 +124,7 @@ def search( results = [] for i in range(len(response["ids"][0])): - result = { - "id": response["ids"][0][i], + result: SearchResult = { "metadata": response["metadatas"][0][i], "context": response["documents"][0][i], "score": response["distances"][0][i], @@ -138,7 +137,7 @@ def search( logging.error(f"Error during {self.type} search: {str(e)}") return [] - def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore + def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any: if not hasattr(self, "app") or not hasattr(self, "collection"): self._initialize_app() diff --git a/src/crewai/memory/user/user_memory.py b/src/crewai/memory/user/user_memory.py index 24e5fe0353..6cde49cba0 100644 --- a/src/crewai/memory/user/user_memory.py +++ b/src/crewai/memory/user/user_memory.py @@ -11,15 +11,46 @@ class UserMemory(Memory): MemoryItem instances. """ - def __init__(self, crew=None): + def __init__(self, crew=None, embedder_config=None, storage=None, path=None, **kwargs): + memory_provider = None + memory_config = None + + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: + memory_config = crew.memory_config + memory_provider = memory_config.get("provider") + + # Initialize with basic parameters + super().__init__( + storage=storage, + embedder_config=embedder_config, + memory_provider=memory_provider + ) + try: - from crewai.memory.storage.mem0_storage import Mem0Storage - except ImportError: - raise ImportError( - "Mem0 is not installed. Please install it with `pip install mem0ai`." + # Try to select storage using helper method + from crewai.memory.storage.rag_storage import RAGStorage + self.storage = self._select_storage( + storage=storage, + memory_config=memory_config, + storage_type="user", + crew=crew, + path=path, + default_storage_factory=lambda path, crew: RAGStorage( + type="user", + crew=crew, + embedder_config=embedder_config, + path=path, + ) + ) + except ValueError: + # Fallback to default storage + from crewai.memory.storage.rag_storage import RAGStorage + self.storage = RAGStorage( + type="user", + crew=crew, + embedder_config=embedder_config, + path=path, ) - storage = Mem0Storage(type="user", crew=crew) - super().__init__(storage) def save( self, @@ -43,3 +74,9 @@ def search( score_threshold=score_threshold, ) return results + + def reset(self) -> None: + try: + self.storage.reset() + except Exception as e: + raise Exception(f"An error occurred while resetting the user memory: {e}") diff --git a/tests/memory/long_term_memory_test.py b/tests/memory/long_term_memory_test.py index 3639054e3f..12f968b4c5 100644 --- a/tests/memory/long_term_memory_test.py +++ b/tests/memory/long_term_memory_test.py @@ -7,7 +7,31 @@ @pytest.fixture def long_term_memory(): """Fixture to create a LongTermMemory instance""" - return LongTermMemory() + # Create a mock storage for testing + from crewai.memory.storage.interface import Storage + + class MockStorage(Storage): + def __init__(self): + self.data = [] + + def save(self, value, metadata): + self.data.append({"value": value, "metadata": metadata}) + + def search(self, query, limit=3, score_threshold=0.35): + return [ + { + "context": item["value"], + "metadata": item["metadata"], + "score": 0.5, + "datetime": item["metadata"].get("datetime", "test_datetime") + } + for item in self.data + ] + + def reset(self): + self.data = [] + + return LongTermMemory(storage=MockStorage()) def test_save_and_search(long_term_memory): @@ -20,7 +44,7 @@ def test_save_and_search(long_term_memory): metadata={"task": "test_task", "quality": 0.5}, ) long_term_memory.save(memory) - find = long_term_memory.search("test_task", latest_n=5)[0] + find = long_term_memory.search(query="test_task", limit=5)[0] assert find["score"] == 0.5 assert find["datetime"] == "test_datetime" assert find["metadata"]["agent"] == "test_agent" diff --git a/tests/memory/short_term_memory_test.py b/tests/memory/short_term_memory_test.py index 6cde2a044b..5ad00f47e0 100644 --- a/tests/memory/short_term_memory_test.py +++ b/tests/memory/short_term_memory_test.py @@ -12,6 +12,8 @@ @pytest.fixture def short_term_memory(): """Fixture to create a ShortTermMemory instance""" + from crewai.memory.storage.rag_storage import RAGStorage + agent = Agent( role="Researcher", goal="Search relevant data and provide results", @@ -25,7 +27,10 @@ def short_term_memory(): expected_output="A list of relevant URLs based on the search query.", agent=agent, ) - return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task])) + + storage = RAGStorage(type="short_term") + crew = Crew(agents=[agent], tasks=[task]) + return ShortTermMemory(storage=storage, crew=crew) def test_save_and_search(short_term_memory): diff --git a/tests/memory/test_custom_storage.py b/tests/memory/test_custom_storage.py new file mode 100644 index 0000000000..dd725b0382 --- /dev/null +++ b/tests/memory/test_custom_storage.py @@ -0,0 +1,211 @@ +from typing import Any, Dict, List + +import pytest + +from crewai.agent import Agent +from crewai.crew import Crew +from crewai.memory.entity.entity_memory import EntityMemory +from crewai.memory.long_term.long_term_memory import LongTermMemory +from crewai.memory.short_term.short_term_memory import ShortTermMemory +from crewai.memory.storage.interface import SearchResult, Storage +from crewai.memory.user.user_memory import UserMemory + + +class CustomStorage(Storage[Any]): + """Custom storage implementation for testing.""" + + def __init__(self): + self.data = [] + + def save(self, value: Any, metadata: Dict[str, Any]) -> None: + self.data.append({"value": value, "metadata": metadata}) + + def search( + self, query: str, limit: int = 3, score_threshold: float = 0.35 + ) -> List[SearchResult]: + return [{"context": item["value"], "metadata": item["metadata"], "score": 0.9} for item in self.data] + + def reset(self) -> None: + self.data = [] + + +def test_custom_storage_with_short_term_memory(): + """Test that custom storage works with short term memory.""" + custom_storage = CustomStorage() + memory = ShortTermMemory(storage=custom_storage) + + memory.save("test value", {"key": "value"}) + results = memory.search("test") + + assert len(results) > 0 + assert results[0]["context"] == "test value" + assert results[0]["metadata"]["key"] == "value" + + +def test_custom_storage_with_long_term_memory(): + """Test that custom storage works with long term memory.""" + custom_storage = CustomStorage() + memory = LongTermMemory(storage=custom_storage) + + memory.save("test value", {"key": "value"}) + results = memory.search("test") + + assert len(results) > 0 + assert results[0]["context"] == "test value" + assert results[0]["metadata"]["key"] == "value" + + +def test_custom_storage_with_entity_memory(): + """Test that custom storage works with entity memory.""" + custom_storage = CustomStorage() + memory = EntityMemory(storage=custom_storage) + + memory.save("test value", {"key": "value"}) + results = memory.search("test") + + assert len(results) > 0 + assert results[0]["context"] == "test value" + assert results[0]["metadata"]["key"] == "value" + + +def test_custom_storage_with_user_memory(): + """Test that custom storage works with user memory.""" + custom_storage = CustomStorage() + memory = UserMemory(storage=custom_storage) + + memory.save("test value", {"key": "value"}) + results = memory.search("test") + + assert len(results) > 0 + # UserMemory prepends "Remember the details about the user: " to the value + assert "test value" in results[0]["context"] + assert results[0]["metadata"]["key"] == "value" + + +def test_custom_storage_with_crew(): + """Test that custom storage works with crew.""" + short_term_storage = CustomStorage() + long_term_storage = CustomStorage() + entity_storage = CustomStorage() + user_storage = CustomStorage() + + # Create memory instances with custom storage + short_term_memory = ShortTermMemory(storage=short_term_storage) + long_term_memory = LongTermMemory(storage=long_term_storage) + entity_memory = EntityMemory(storage=entity_storage) + user_memory = UserMemory(storage=user_storage) + + # Create a crew with custom memory instances + crew = Crew( + agents=[Agent(role="test", goal="test", backstory="test")], + memory=True, + short_term_memory=short_term_memory, + long_term_memory=long_term_memory, + entity_memory=entity_memory, + memory_config={"user_memory": user_memory}, + ) + + # Test that the crew has the custom memory instances + assert crew._short_term_memory.storage == short_term_storage + assert crew._long_term_memory.storage == long_term_storage + assert crew._entity_memory.storage == entity_storage + assert crew._user_memory.storage == user_storage + + +def test_custom_storage_with_memory_config(): + """Test that custom storage works with memory_config.""" + short_term_storage = CustomStorage() + long_term_memory = LongTermMemory(storage=CustomStorage()) + entity_memory = EntityMemory(storage=CustomStorage()) + user_memory = UserMemory(storage=CustomStorage()) + + # Create a crew with custom storage in memory_config + crew = Crew( + agents=[Agent(role="test", goal="test", backstory="test")], + memory=True, + short_term_memory=ShortTermMemory(storage=short_term_storage), + long_term_memory=long_term_memory, + entity_memory=entity_memory, + memory_config={ + "user_memory": user_memory + }, + ) + + # Test that the crew has the custom storage instances + assert crew._short_term_memory.storage == short_term_storage + assert crew._long_term_memory == long_term_memory + assert crew._entity_memory == entity_memory + assert crew._user_memory == user_memory + + +def test_custom_storage_error_handling(): + """Test error handling with custom storage.""" + # Test exception propagation + class ErrorStorage(Storage[Any]): + """Storage implementation that raises exceptions.""" + def __init__(self): + self.data = [] + + def save(self, value: Any, metadata: Dict[str, Any]) -> None: + raise ValueError("Save error") + + def search( + self, query: str, limit: int = 3, score_threshold: float = 0.35 + ) -> List[SearchResult]: + raise ValueError("Search error") + + def reset(self) -> None: + raise ValueError("Reset error") + + storage = ErrorStorage() + memory = ShortTermMemory(storage=storage) + + with pytest.raises(ValueError, match="Save error"): + memory.save("test", {}) + + with pytest.raises(ValueError, match="Search error"): + memory.search("test") + + with pytest.raises(Exception, match="An error occurred while resetting the short-term memory: Reset error"): + memory.reset() + + +def test_custom_storage_edge_cases(): + """Test edge cases with custom storage.""" + class EdgeCaseStorage(Storage[Any]): + """Storage implementation for testing edge cases.""" + def __init__(self): + self.data = [] + + def save(self, value: Any, metadata: Dict[str, Any]) -> None: + self.data.append({"value": value, "metadata": metadata}) + + def search( + self, query: str, limit: int = 3, score_threshold: float = 0.35 + ) -> List[SearchResult]: + return [{"context": item["value"], "metadata": item["metadata"], "score": 0.5} for item in self.data] + + def reset(self) -> None: + self.data = [] + + storage = EdgeCaseStorage() + memory = ShortTermMemory(storage=storage) + + # Test empty query + memory.save("test value", {"key": "value"}) + results = memory.search("") + assert len(results) > 0 + + # Test very large metadata + large_metadata = {"key" + str(i): "value" * 100 for i in range(100)} + memory.save("test value", large_metadata) + results = memory.search("test") + assert len(results) > 0 + assert results[1]["metadata"] == large_metadata + + # Test unicode and special characters + unicode_value = "测试值 with special chars: !@#$%^&*()" + memory.save(unicode_value, {"key": "value"}) + results = memory.search("测试") + assert len(results) > 0 + assert unicode_value in results[2]["context"]