# --- app/alembic/versions/20240823164559_05069444feee_project_id_to_string_anddelete_col.py ---
"""Project id to string and delete columns.

Revision ID: 20240823164559_05069444feee
Revises: 20240820182032_d3f532773223
Create Date: 2024-08-23 16:45:59.991109

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = '20240823164559_05069444feee'
down_revision: Union[str, None] = '20240820182032_d3f532773223'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Widen projects.id to text, drop legacy columns, refresh the status check."""
    # BUGFIX: PostgreSQL needs an explicit USING cast when a column's type
    # changes without an implicit/assignment cast; spell it out for safety.
    op.alter_column(
        'projects', 'id',
        existing_type=sa.INTEGER(),
        type_=sa.Text(),
        existing_nullable=False,
        postgresql_using='id::text',
    )
    op.drop_constraint('projects_directory_key', 'projects', type_='unique')
    op.drop_column('projects', 'directory')
    op.drop_column('projects', 'is_default')
    op.drop_column('projects', 'project_name')
    # Replace the old status check with the new lifecycle states.
    op.drop_constraint('check_status', 'projects', type_='check')
    op.create_check_constraint(
        'check_status', 'projects',
        "status IN ('submitted', 'cloned', 'parsed', 'ready', 'error')",
    )


def downgrade() -> None:
    """Restore the dropped columns and revert projects.id to integer."""
    op.add_column('projects', sa.Column('project_name', sa.TEXT(), autoincrement=False, nullable=True))
    op.add_column('projects', sa.Column('is_default', sa.BOOLEAN(), autoincrement=False, nullable=True))
    op.add_column('projects', sa.Column('directory', sa.TEXT(), autoincrement=False, nullable=True))
    op.create_unique_constraint('projects_directory_key', 'projects', ['directory'])
    # BUGFIX: text -> integer has no implicit cast in PostgreSQL, so this
    # alter fails without USING. Note it still errors on non-numeric ids.
    op.alter_column(
        'projects', 'id',
        existing_type=sa.Text(),
        type_=sa.INTEGER(),
        existing_nullable=False,
        postgresql_using='id::integer',
    )
    op.drop_constraint('check_status', 'projects', type_='check')
    op.create_check_constraint(
        'check_status', 'projects',
        "status IN ('created', 'ready', 'error')",
    )
    # ### end Alembic commands ###


# --- app/core/config.py ---
from dotenv import load_dotenv
import os

load_dotenv()


class ConfigProvider:
    """Snapshot Neo4j and GitHub credentials from the environment at import time."""

    def __init__(self):
        self.neo4j_config = {
            "uri": os.getenv("NEO4J_URI"),
            "username": os.getenv("NEO4J_USERNAME"),
            "password": os.getenv("NEO4J_PASSWORD"),
        }
        self.github_key = os.getenv("GITHUB_PRIVATE_KEY")

    def get_neo4j_config(self):
        """Return the dict of Neo4j connection settings (values may be None)."""
        return self.neo4j_config

    def get_github_key(self):
        """Return the raw GitHub private key body (no PEM armor)."""
        return self.github_key


# Module-level singleton used throughout the app.
config_provider = ConfigProvider()
# --- app/core/mongo_manager.py ---
import os
import logging
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure
from typing import Optional
import certifi


class MongoManager:
    """Process-wide singleton around a pooled MongoClient.

    Obtain it with MongoManager.get_instance(); constructing a second
    instance directly raises RuntimeError.
    """

    _instance = None
    _client: Optional[MongoClient] = None
    _db = None

    @classmethod
    def get_instance(cls):
        # Lazily create the singleton on first access.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # Guard against bypassing the singleton accessor once it exists.
        if self._instance is not None:
            raise RuntimeError("Use get_instance() to get the MongoManager instance")
        self._connect()

    def _connect(self):
        """Open client/db handles if absent.

        Raises ValueError when MONGO_URI or MONGODB_DB_NAME is unset, and
        ConnectionFailure when the server is unreachable.
        """
        if self._client is None:
            try:
                mongodb_uri = os.environ.get("MONGO_URI")
                if not mongodb_uri:
                    raise ValueError("MONGO_URI environment variable is not set")

                self._client = MongoClient(
                    mongodb_uri,
                    maxPoolSize=50,
                    waitQueueTimeoutMS=2500,
                    tlsCAFile=certifi.where()  # Use the certifi package to locate the CA bundle
                )

                db_name = os.environ.get("MONGODB_DB_NAME")
                if not db_name:
                    raise ValueError("MONGODB_DB_NAME environment variable is not set")

                self._db = self._client[db_name]

                # Fail fast: verify server reachability and db access now.
                self.verify_connection()

            except (ConnectionFailure, ValueError) as e:
                logging.error(f"Failed to connect to MongoDB: {str(e)}")
                raise

    def verify_connection(self):
        """Ping the server and list collections to prove database access."""
        try:
            self._client.admin.command('ping')
            self._db.list_collection_names()
            logging.info("Successfully connected to MongoDB and verified database access")
        except OperationFailure as e:
            logging.error(f"Failed to verify MongoDB connection: {str(e)}")
            raise

    def get_collection(self, collection_name: str):
        self._connect()  # Transparently reconnect after close()
        return self._db[collection_name]

    def put(self, collection_name: str, document_id: str, data: dict):
        """Upsert `data` into the document whose _id is `document_id`."""
        try:
            collection = self.get_collection(collection_name)
            result = collection.update_one(
                {"_id": document_id},
                {"$set": data},
                upsert=True
            )
            # BUGFIX: modified_count is 0 both for an upsert-insert and for a
            # matched-but-identical update, so it mislabeled the action.
            # upserted_id is non-None exactly when a document was inserted.
            action = "inserted" if result.upserted_id is not None else "updated"
            logging.info(f"Document {action} in {collection_name}")
            return result
        except Exception as e:
            logging.error(f"Failed to put document in {collection_name}: {str(e)}")
            raise

    def get(self, collection_name: str, document_id: str):
        """Return the document with _id == document_id, or None if absent."""
        try:
            collection = self.get_collection(collection_name)
            document = collection.find_one({"_id": document_id})
            if document:
                logging.info(f"Document retrieved from {collection_name}")
            else:
                logging.info(f"Document not found in {collection_name}")
            return document
        except Exception as e:
            logging.error(f"Failed to get document from {collection_name}: {str(e)}")
            raise

    def delete(self, collection_name: str, document_id: str):
        """Delete the document with _id == document_id; returns the DeleteResult."""
        try:
            collection = self.get_collection(collection_name)
            result = collection.delete_one({"_id": document_id})
            if result.deleted_count:
                logging.info(f"Document deleted from {collection_name}")
            else:
                logging.info(f"Document not found in {collection_name}")
            return result
        except Exception as e:
            logging.error(f"Failed to delete document from {collection_name}: {str(e)}")
            raise

    def close(self):
        """Close the client and clear handles; a later call reconnects lazily."""
        if self._client:
            self._client.close()
            self._client = None
            self._db = None
            logging.info("MongoDB connection closed")

    def reconnect(self):
        self.close()
        self._connect()
        logging.info("Reconnected to MongoDB")

    @classmethod
    def close_connection(cls):
        """Close the singleton's client and drop the instance (app shutdown)."""
        if cls._instance:
            cls._instance.close()
            cls._instance = None
            logging.info("MongoDB connection closed and instance reset")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Intentionally keep the shared connection open; its lifecycle is
        # managed by close_connection() at application shutdown.
        pass
-1,6 +1,9 @@ import os import logging +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + from dotenv import load_dotenv from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -12,6 +15,13 @@ from app.modules.utils.dummy_setup import DummyDataSetup +from app.modules.utils.firebase_setup import FirebaseSetup +from app.modules.parsing.graph_construction.parsing_router import router as parsing_router +from app.modules.auth.auth_router import auth_router +from app.modules.key_management.secret_manager import router as secret_manager_router + +from app.core.mongo_manager import MongoManager + class MainApp: def __init__(self): load_dotenv(override=True) @@ -19,8 +29,21 @@ def __init__(self): self.setup_cors() self.initialize_database() self.check_and_set_env_vars() - self.setup_data() + if os.getenv("isDevelopmentMode") == "enabled": + self.setup_data() + else: + FirebaseSetup.firebase_init() self.include_routers() + self.verify_mongodb_connection() + + def verify_mongodb_connection(self): + try: + mongo_manager = MongoManager.get_instance() + mongo_manager.verify_connection() + logging.info("MongoDB connection verified successfully") + except Exception as e: + logging.error(f"Failed to verify MongoDB connection: {str(e)}") + raise def setup_cors(self): origins = ["*"] @@ -56,6 +79,9 @@ def setup_data(self): def include_routers(self): self.app.include_router(user_router, prefix="/api/v1", tags=["User"]) self.app.include_router(conversations_router, prefix="/api/v1", tags=["Conversations"]) + self.app.include_router(parsing_router, prefix="/api/v1", tags=["Parsing"]) + self.app.include_router(auth_router, prefix="/api/v1", tags=["Auth"]) + self.app.include_router(secret_manager_router, prefix="/api/v1", tags=["Secret Manager"]) def add_health_check(self): @@ -67,7 +93,10 @@ def run(self): self.add_health_check() return self.app - # Create an instance of MainApp and run it main_app 
# --- app/modules/auth/auth_router.py ---
import json
import logging
import os

from datetime import datetime
from dotenv import load_dotenv

from fastapi import Depends, Request
from fastapi.responses import JSONResponse, Response
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.modules.auth.auth_service import auth_handler
from app.modules.users.user_service import UserService

from app.modules.utils.APIRouter import APIRouter

from .auth_schema import LoginRequest
from app.modules.users.user_schema import CreateUser

auth_router = APIRouter()
load_dotenv(override=True)


class AuthAPI:
    @auth_router.post("/login")
    async def login(login_request: LoginRequest):
        """Exchange email/password for an ID token via the auth handler."""
        email, password = login_request.email, login_request.password

        try:
            res = auth_handler.login(email=email, password=password)
            id_token = res.get("idToken")
            return JSONResponse(content={"token": id_token}, status_code=200)
        except Exception as e:
            # Credential or upstream failure: surface as a 400 with detail.
            return JSONResponse(
                content={"error": f"ERROR: {str(e)}"}, status_code=400
            )

    @auth_router.post("/signup")
    async def signup(request: Request, db: Session = Depends(get_db)):
        """Create the user record on first login; otherwise refresh last_login."""
        body = json.loads(await request.body())
        uid = body["uid"]
        user_service = UserService(db)
        user = user_service.get_user_by_uid(uid)
        if user:
            message, error = user_service.update_last_login(uid)
            if error:
                return Response(content=message, status_code=400)
            return Response(content=json.dumps({"uid": uid}), status_code=200)
        # First login: persist the profile supplied by the client.
        first_login = datetime.utcnow()
        user = CreateUser(
            uid=uid,
            email=body["email"],
            display_name=body["displayName"],
            email_verified=body["emailVerified"],
            created_at=first_login,
            last_login_at=first_login,
            provider_info=body["providerData"][0],
            provider_username=body["providerUsername"]
        )
        uid, message, error = user_service.create_user(user)
        if error:
            return Response(content=message, status_code=400)
        return Response(content=json.dumps({"uid": uid}), status_code=201)


# --- app/modules/auth/auth_schema.py ---
from pydantic import BaseModel


class LoginRequest(BaseModel):
    """Request body for POST /login."""
    email: str
    password: str


# --- app/modules/auth/auth_service.py ---
import requests
from firebase_admin import auth
from fastapi import HTTPException, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer


class AuthService:
    def login(self, email, password):
        """Sign in against the Google Identity Toolkit REST API.

        Returns the parsed JSON response; raises with the error payload on
        a non-2xx status.
        """
        log_prefix = "AuthService::login:"
        identity_tool_kit_id = os.getenv("GOOGLE_IDENTITY_TOOL_KIT_KEY")
        identity_url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={identity_tool_kit_id}"

        user_auth_response = requests.post(
            url=identity_url,
            json={
                "email": email,
                "password": password,
                "returnSecureToken": True,
            },
        )

        try:
            user_auth_response.raise_for_status()
            return user_auth_response.json()
        except Exception as e:
            logging.exception(f"{log_prefix} {str(e)}")
            raise Exception(user_auth_response.json())

    def signup(self, email, password, name):
        """Create a Firebase user and return the user record."""
        user = auth.create_user(
            email=email, password=password, display_name=name
        )
        return user

    # BUGFIX: the original stacked @classmethod on top of @staticmethod.
    # classmethod(staticmethod(f)) is not callable on Python < 3.10 and
    # breaks FastAPI's dependency signature introspection; a plain
    # @staticmethod is what Depends(AuthService.check_auth) needs.
    @staticmethod
    async def check_auth(
        request: Request,
        res: Response,
        credential: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        """FastAPI dependency: verify the bearer token (or dev-mode bypass).

        Returns the decoded token dict and stashes it on request.state.user.
        Raises HTTP 401 when the token is missing or invalid.
        """
        # Development mode with no credential: act as the default user.
        if os.getenv("isDevelopmentMode") == "enabled" and credential is None:
            request.state.user = {"user_id": os.getenv("defaultUsername")}
            return {"user_id": os.getenv("defaultUsername")}
        if credential is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Bearer authentication is needed",
                headers={"WWW-Authenticate": 'Bearer realm="auth_required"'},
            )
        try:
            decoded_token = auth.verify_id_token(credential.credentials)
            request.state.user = decoded_token
        except Exception as err:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"Invalid authentication from Firebase. {err}",
                headers={"WWW-Authenticate": 'Bearer error="invalid_token"'},
            )
        res.headers["WWW-Authenticate"] = 'Bearer realm="auth_required"'
        return decoded_token


auth_handler = AuthService()
Conversation.user = relationship("User", back_populates="conversations") -Conversation.messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") \ No newline at end of file +Conversation.messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index 53820b91..4b8c2373 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -291,4 +291,4 @@ async def stop_generation(self, conversation_id: str) -> dict: # Implement the logic to stop the generation process # This might involve setting a flag in the orchestrator or cancelling an ongoing task logger.info(f"Attempting to stop generation for conversation {conversation_id}") - return {"status": "success", "message": "Generation stop request received"} \ No newline at end of file + return {"status": "success", "message": "Generation stop request received"} diff --git a/app/modules/conversations/conversations_router.py b/app/modules/conversations/conversations_router.py index f5ae4a16..b4a72e0f 100644 --- a/app/modules/conversations/conversations_router.py +++ b/app/modules/conversations/conversations_router.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, Query from fastapi.responses import StreamingResponse from app.core.database import get_db +from app.modules.auth.auth_service import AuthService from app.modules.conversations.conversation.conversation_controller import ConversationController from .conversation.conversation_schema import ( CreateConversationRequest, @@ -18,7 +19,8 @@ class ConversationAPI: @router.post("/conversations/", response_model=CreateConversationResponse) async def create_conversation( conversation: CreateConversationRequest, - db: Session = Depends(get_db) + db: Session = 
Depends(get_db), + user=Depends(AuthService.check_auth) ): controller = ConversationController(db) return await controller.create_conversation(conversation) @@ -27,7 +29,8 @@ async def create_conversation( @router.get("/conversations/{conversation_id}/info/", response_model=ConversationInfoResponse) async def get_conversation_info( conversation_id: str, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + user=Depends(AuthService.check_auth) ): controller = ConversationController(db) return await controller.get_conversation_info(conversation_id) @@ -38,7 +41,8 @@ async def get_conversation_messages( conversation_id: str, start: int = Query(0, ge=0), limit: int = Query(10, ge=1), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + user=Depends(AuthService.check_auth) ): controller = ConversationController(db) return await controller.get_conversation_messages(conversation_id, start, limit) @@ -49,7 +53,8 @@ async def post_message( conversation_id: str, message: MessageRequest, user_id: str, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + user=Depends(AuthService.check_auth) ): controller = ConversationController(db) @@ -62,7 +67,8 @@ async def post_message( @router.post("/conversations/{conversation_id}/regenerate/", response_model=MessageResponse) async def regenerate_last_message( conversation_id: str, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + user=Depends(AuthService.check_auth) ): controller = ConversationController(db) return StreamingResponse(controller.regenerate_last_message(conversation_id), media_type="text/event-stream") @@ -71,7 +77,8 @@ async def regenerate_last_message( @router.delete("/conversations/{conversation_id}/", response_model=dict) async def delete_conversation( conversation_id: str, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + user=Depends(AuthService.check_auth) ): controller = ConversationController(db) return await 
# --- app/modules/github/github_service.py ---
import base64
import logging
import os

import requests
from fastapi import HTTPException
from github import Github
from github.Auth import AppAuth
from sqlalchemy.orm import Session

from app.modules.projects.projects_schema import ProjectStatusEnum
from app.modules.projects.projects_service import ProjectService
from app.core.config import config_provider

logger = logging.getLogger(__name__)


class GithubService:

    def __init__(self, db: Session):
        self.project_manager = ProjectService(db)

    @staticmethod
    def get_github_repo_details(repo_name):
        """Look up the GitHub App installation for `owner/repo`.

        Returns (requests.Response for the installation lookup, AppAuth, owner).
        """
        # Re-wrap the bare key body from config in PEM armor.
        private_key = (
            "-----BEGIN RSA PRIVATE KEY-----\n"
            + config_provider.get_github_key()
            + "\n-----END RSA PRIVATE KEY-----\n"
        )
        app_id = os.environ["GITHUB_APP_ID"]
        auth = AppAuth(app_id=app_id, private_key=private_key)
        jwt = auth.create_jwt()
        owner, repo = repo_name.split('/')[0], repo_name.split('/')[1]
        url = f"https://api.github.com/repos/{owner}/{repo}/installation"
        headers = {
            "Accept": "application/vnd.github+json",
            "Authorization": f"Bearer {jwt}",
            "X-GitHub-Api-Version": "2022-11-28"
        }

        # ROBUSTNESS: a missing timeout can hang the worker indefinitely.
        return requests.get(url, headers=headers, timeout=30), auth, owner

    @staticmethod
    def check_is_commit_added(repo_details, project_details, branch_name):
        """True when the branch head moved past the stored commit, or the
        project is not in READY state (i.e. re-parsing is needed)."""
        branch = repo_details.get_branch(branch_name)
        latest_commit_sha = branch.commit.sha
        # project_details[3] is the stored commit sha, [4] the project status.
        return not (
            latest_commit_sha == project_details[3]
            and project_details[4] == ProjectStatusEnum.READY
        )

    @staticmethod
    def fetch_method_from_repo(node, db):
        """Fetch the source lines [start, end] for `node` from GitHub.

        Returns the method text, or None when any step fails (errors are
        logged, not raised).
        """
        method_content = None
        github = None
        try:
            project_id = node["project_id"]
            project_manager = ProjectService(db)
            repo_details = project_manager.get_repo_and_branch_name(project_id=project_id)
            repo_name = repo_details[0]
            branch_name = repo_details[1]

            # node["id"] is "<path>:<suffix>"; strip any leading slash.
            file_path = node["id"].split(':')[0].lstrip('/')
            start_line = node["start"]
            end_line = node["end"]

            response, auth, _ = GithubService.get_github_repo_details(repo_name)

            if response.status_code != 200:
                raise HTTPException(
                    status_code=400, detail="Failed to get installation ID"
                )

            app_auth = auth.get_installation_auth(response.json()["id"])
            github = Github(auth=app_auth)
            repo = github.get_repo(repo_name)
            file_contents = repo.get_contents(file_path.replace("\\", "/"), ref=branch_name)
            decoded_content = base64.b64decode(file_contents.content).decode('utf-8')
            lines = decoded_content.split('\n')
            # GitHub line numbers are 1-based; slice is end-inclusive.
            method_lines = lines[start_line - 1:end_line]
            method_content = '\n'.join(method_lines)

        except Exception as e:
            logger.error(f"An error occurred: {e}", exc_info=True)

        return method_content
# --- app/modules/key_management/secret_manager.py ---
import os
from typing import Literal

from fastapi import Depends, HTTPException
from google.cloud import secretmanager
from app.modules.utils.APIRouter import APIRouter

from app.modules.auth.auth_service import AuthService
from app.core.mongo_manager import MongoManager
from app.modules.key_management.secrets_schema import CreateSecretRequest, UpdateSecretRequest

router = APIRouter()


class SecretManager:
    @staticmethod
    def get_client_and_project():
        """Return (SecretManagerServiceClient, project id).

        NOTE(review): in development mode both are None, and the endpoints
        below would fail with AttributeError — confirm these routes are
        never exercised in dev mode.
        """
        if os.getenv("isDevelopmentMode") == "disabled":
            client = secretmanager.SecretManagerServiceClient()
            project_id = os.environ.get("GCP_PROJECT")
        else:
            client = None
            project_id = None
        return client, project_id

    @router.post("/secrets")
    def create_secret(request: CreateSecretRequest, user=Depends(AuthService.check_auth)):
        """Store the caller's provider preference and create the secret."""
        customer_id = user["user_id"]
        client, project_id = SecretManager.get_client_and_project()

        mongo_manager = MongoManager.get_instance()
        mongo_manager.put(
            "preferences",
            customer_id,
            {"provider": request.provider}
        )

        api_key = request.api_key
        secret_id = SecretManager.get_secret_id(request.provider, customer_id)
        parent = f"projects/{project_id}"

        secret = {"replication": {"automatic": {}}}
        response = client.create_secret(
            request={"parent": parent, "secret_id": secret_id, "secret": secret}
        )

        version = {"payload": {"data": api_key.encode("UTF-8")}}
        client.add_secret_version(
            request={"parent": response.name, "payload": version["payload"]}
        )

        return {"message": "Secret created successfully"}

    @staticmethod
    def get_secret_id(provider: Literal["openai"], customer_id: str):
        """Map (provider, customer) to the GCP secret id; 400 on bad provider."""
        if provider == "openai":
            secret_id = f"openai-api-key-{customer_id}"
        else:
            raise HTTPException(status_code=400, detail="Invalid provider")
        return secret_id

    @router.get("/secrets/{provider}")
    def get_secret_for_provider(provider: Literal["openai"], user=Depends(AuthService.check_auth)):
        customer_id = user["user_id"]
        return SecretManager.get_secret(provider, customer_id)

    @staticmethod
    def get_secret(provider: Literal["openai"], customer_id: str):
        """Read the latest secret version; 404 when access fails."""
        client, project_id = SecretManager.get_client_and_project()
        secret_id = SecretManager.get_secret_id(provider, customer_id)
        name = f"projects/{project_id}/secrets/{secret_id}/versions/latest"

        try:
            response = client.access_secret_version(request={"name": name})
            api_key = response.payload.data.decode("UTF-8")
            return {"api_key": api_key}
        except Exception:
            raise HTTPException(status_code=404, detail="Secret not found")

    @router.put("/secrets/")
    def update_secret(request: UpdateSecretRequest, user=Depends(AuthService.check_auth)):
        """Add a new secret version and refresh the stored provider preference."""
        customer_id = user["user_id"]
        api_key = request.api_key
        secret_id = SecretManager.get_secret_id(request.provider, customer_id)
        client, project_id = SecretManager.get_client_and_project()
        parent = f"projects/{project_id}/secrets/{secret_id}"
        version = {"payload": {"data": api_key.encode("UTF-8")}}
        client.add_secret_version(
            request={"parent": parent, "payload": version["payload"]}
        )
        mongo_manager = MongoManager.get_instance()
        mongo_manager.put(
            "preferences",
            customer_id,
            {"provider": request.provider}
        )

        return {"message": "Secret updated successfully"}

    @router.delete("/secrets/{provider}")
    def delete_secret(provider: Literal["openai"], user=Depends(AuthService.check_auth)):
        """Delete the secret and the caller's preferences document.

        NOTE(review): a Mongo failure here is also reported as 404 by the
        broad except — confirm that is acceptable.
        """
        customer_id = user["user_id"]
        secret_id = SecretManager.get_secret_id(provider, customer_id)
        client, project_id = SecretManager.get_client_and_project()
        name = f"projects/{project_id}/secrets/{secret_id}"

        try:
            client.delete_secret(request={"name": name})
            mongo_manager = MongoManager.get_instance()
            mongo_manager.delete("preferences", customer_id)
            return {"message": "Secret deleted successfully"}
        except Exception:
            raise HTTPException(status_code=404, detail="Secret not found")


# --- app/modules/key_management/secrets_schema.py ---
from pydantic import BaseModel, validator
import re


def validate_openai_api_key_format(api_key: str) -> bool:
    """Accept standard (sk-...) and project (sk-proj-...) OpenAI keys.

    GENERALIZED: real key lengths vary (project keys exceed 48 characters
    and may include '_' or '-'), so require *at least* 48 key characters.
    Every key the old exact-length patterns accepted is still accepted.
    """
    pattern = r"^sk-(proj-)?[a-zA-Z0-9_-]{48,}$"
    return bool(re.match(pattern, api_key))


class BaseSecretRequest(BaseModel):
    api_key: str
    provider: Literal["openai"] = "openai"

    @validator("api_key")
    def api_key_format(cls, v: str) -> str:
        if not validate_openai_api_key_format(v):
            raise ValueError("Invalid OpenAI API key format")
        return v


class UpdateSecretRequest(BaseSecretRequest):
    pass


class CreateSecretRequest(BaseSecretRequest):
    pass
# --- app/modules/parsing/graph_construction/parsing_helper.py (ParseHelper section) ---
import json
import logging
import os
import tarfile
import warnings
from collections import namedtuple

import requests
from fastapi import HTTPException
from git import Repo, GitCommandError
from uuid6 import uuid7

from app.modules.projects.projects_schema import ProjectStatusEnum
from app.modules.projects.projects_service import ProjectService
from sqlalchemy.orm import Session

# tree_sitter is throwing a FutureWarning
warnings.simplefilter("ignore", category=FutureWarning)
Tag = namedtuple("Tag", "rel_fname fname line end_line name kind type".split())


class ParsingServiceError(Exception):
    """Base exception class for ParsingService errors."""


class ParsingFailedError(ParsingServiceError):
    """Raised when a parsing fails."""


class ParseHelper:
    def __init__(self, db_session: Session):
        self.project_manager = ProjectService(db_session)
        self.db = db_session

    def download_and_extract_tarball(self, repo, branch, target_dir, auth, repo_details, user_id):
        """Download the branch tarball and extract it under target_dir.

        Returns the extraction directory on success.
        NOTE(review): on failure this *returns* the exception object instead
        of raising — callers receive an Exception where a path is expected;
        confirm callers handle that before changing it.
        """
        try:
            tarball_url = repo_details.get_archive_link("tarball", branch)
            # ROBUSTNESS: timeout added so a stalled download cannot hang forever.
            # NOTE(review): Authorization value has no scheme (e.g. "token <t>");
            # confirm GitHub accepts the raw token here.
            response = requests.get(
                tarball_url,
                stream=True,
                headers={"Authorization": f"{auth.token}"},
                timeout=60,
            )
            response.raise_for_status()  # Check for request errors
        except requests.exceptions.RequestException as e:
            logging.error(f"Error fetching tarball: {e}")
            return e

        tarball_path = os.path.join(target_dir, f"{repo.full_name.replace('/', '-')}-{branch}.tar.gz")
        try:
            with open(tarball_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except IOError as e:
            logging.error(f"Error writing tarball to file: {e}")
            return e

        final_dir = os.path.join(target_dir, f"{repo.full_name.replace('/', '-')}-{branch}-{user_id}")
        try:
            with tarfile.open(tarball_path, "r:gz") as tar:
                for member in tar.getmembers():
                    # Strip the tarball's single top-level directory.
                    member_path = os.path.join(
                        final_dir,
                        os.path.relpath(member.name, start=member.name.split("/")[0]),
                    )
                    if member.isdir():
                        os.makedirs(member_path, exist_ok=True)
                    else:
                        member_dir = os.path.dirname(member_path)
                        if not os.path.exists(member_dir):
                            os.makedirs(member_dir)
                        with open(member_path, "wb") as f:
                            if member.size > 0:
                                f.write(tar.extractfile(member).read())
        except (tarfile.TarError, IOError) as e:
            logging.error(f"Error extracting tarball: {e}")
            return e

        try:
            os.remove(tarball_path)
        except OSError as e:
            logging.error(f"Error removing tarball: {e}")
            return e

        return final_dir

    # Extension -> language key used by detect_repo_language.
    _EXT_TO_LANG = {
        ".cs": "c_sharp", ".c": "c", ".cpp": "cpp", ".cxx": "cpp", ".cc": "cpp",
        ".el": "elisp", ".ex": "elixir", ".exs": "elixir", ".elm": "elm",
        ".go": "go", ".java": "java", ".js": "javascript", ".jsx": "javascript",
        ".ml": "ocaml", ".mli": "ocaml", ".php": "php", ".py": "python",
        ".ql": "ql", ".rb": "ruby", ".rs": "rust",
        ".ts": "typescript", ".tsx": "typescript",
    }

    @staticmethod
    def detect_repo_language(repo_dir):
        """Return the predominant language key under repo_dir, or "other".

        Counts files per language; unreadable/non-UTF-8 files are skipped.
        (IDIOM: the long elif chain was replaced by a lookup table; the
        unused character counter was removed.)
        """
        lang_count = {
            "c_sharp": 0, "c": 0, "cpp": 0, "elisp": 0, "elixir": 0, "elm": 0,
            "go": 0, "java": 0, "javascript": 0, "ocaml": 0, "php": 0, "python": 0,
            "ql": 0, "ruby": 0, "rust": 0, "typescript": 0, "other": 0
        }
        for root, _, files in os.walk(repo_dir):
            for file in files:
                file_path = os.path.join(root, file)
                ext = os.path.splitext(file)[1].lower()
                try:
                    # Read to confirm the file is valid UTF-8 text; binary or
                    # vanished files are skipped entirely (not counted).
                    with open(file_path, 'r', encoding='utf-8') as f:
                        f.read()
                except (UnicodeDecodeError, FileNotFoundError):
                    continue
                lang_count[ParseHelper._EXT_TO_LANG.get(ext, "other")] += 1
        # Determine the predominant language based on counts
        predominant_language = max(lang_count, key=lang_count.get)
        return predominant_language if lang_count[predominant_language] > 0 else "other"

    async def setup_project_directory(
        self, repo, branch, auth, repo_details, user_id, project_id=None
    ):
        """Register the project (if new), materialize its sources, and record
        repo metadata + latest commit. Returns (extracted_dir, project_id)."""
        if not project_id:
            pid = str(uuid7())
            project_id = await self.project_manager.register_project(
                f"{repo.full_name}",
                branch,
                user_id,
                pid,
            )

        if isinstance(repo_details, Repo):
            # Local clone: check out the requested branch in place.
            extracted_dir = repo_details.working_tree_dir
            try:
                current_dir = os.getcwd()
                os.chdir(extracted_dir)  # Change to the cloned repo directory
                repo_details.git.checkout(branch)
            except GitCommandError as e:
                logging.error(f"Error checking out branch: {e}")
                raise HTTPException(
                    status_code=400, detail=f"Failed to checkout branch {branch}"
                )
            finally:
                os.chdir(current_dir)  # Restore the original working directory
            branch_details = repo_details.head.commit
            latest_commit_sha = branch_details.hexsha
        else:
            # Remote repo: download and unpack the branch tarball.
            extracted_dir = self.download_and_extract_tarball(
                repo, branch, os.getenv("PROJECT_PATH"), auth, repo_details, user_id
            )
            branch_details = repo_details.get_branch(branch)
            latest_commit_sha = branch_details.commit.sha

        repo_metadata = ParseHelper.extract_repository_metadata(repo_details)
        repo_metadata["error_message"] = None
        project_metadata = json.dumps(repo_metadata).encode("utf-8")
        ProjectService.update_project(
            self.db,
            project_id,
            properties=project_metadata,
            commit_id=latest_commit_sha,
            status=ProjectStatusEnum.CLONED.value,
        )

        return extracted_dir, project_id

    # BUGFIX: the four metadata helpers below were defined without `self` yet
    # undecorated; they only worked when invoked on the class. @staticmethod
    # makes that explicit and keeps instance calls from breaking.
    @staticmethod
    def extract_repository_metadata(repo):
        """Dispatch to the local- or remote-repo metadata extractor."""
        if isinstance(repo, Repo):
            metadata = ParseHelper.extract_local_repo_metadata(repo)
        else:
            metadata = ParseHelper.extract_remote_repo_metadata(repo)
        return metadata

    @staticmethod
    def extract_local_repo_metadata(repo):
        """Build the metadata dict for a locally cloned GitPython Repo."""
        languages = ParseHelper.get_local_repo_languages(repo.working_tree_dir)
        total_bytes = sum(languages.values())

        metadata = {
            "basic_info": {
                "full_name": os.path.basename(repo.working_tree_dir),
                "description": None,
                "created_at": None,
                "updated_at": None,
                "default_branch": repo.head.ref.name,
            },
            "metrics": {
                "size": ParseHelper.get_directory_size(repo.working_tree_dir),
                "stars": None,
                "forks": None,
                "watchers": None,
                "open_issues": None,
            },
            "languages": {
                "breakdown": languages,
                "total_bytes": total_bytes,
            },
            "commit_info": {
                "total_commits": len(list(repo.iter_commits()))
            },
            "contributors": {
                # NOTE(review): this counts commits on --all, not distinct
                # contributors — confirm intended semantics.
                "count": len(list(repo.iter_commits('--all'))),
            },
            "topics": [],
        }

        return metadata

    @staticmethod
    def get_local_repo_languages(path):
        """Rough byte breakdown: Python (.py) vs everything else."""
        total_bytes = 0
        python_bytes = 0

        for dirpath, _, filenames in os.walk(path):
            for filename in filenames:
                file_extension = os.path.splitext(filename)[1]
                file_path = os.path.join(dirpath, filename)
                file_size = os.path.getsize(file_path)
                total_bytes += file_size
                if file_extension == '.py':
                    python_bytes += file_size

        languages = {}
        if total_bytes > 0:
            languages['Python'] = python_bytes
            languages['Other'] = total_bytes - python_bytes

        return languages

    @staticmethod
    def extract_remote_repo_metadata(repo):
        """Build the metadata dict for a PyGithub Repository object."""
        languages = repo.get_languages()
        total_bytes = sum(languages.values())

        metadata = {
            "basic_info": {
                "full_name": repo.full_name,
                "description": repo.description,
                "created_at": repo.created_at.isoformat(),
                "updated_at": repo.updated_at.isoformat(),
                "default_branch": repo.default_branch,
            },
            "metrics": {
                "size": repo.size,
                "stars": repo.stargazers_count,
                "forks": repo.forks_count,
                "watchers": repo.watchers_count,
                "open_issues": repo.open_issues_count,
            },
            "languages": {
                "breakdown": languages,
                "total_bytes": total_bytes,
            },
            "commit_info": {
                "total_commits": repo.get_commits().totalCount
            },
            "contributors": {
                "count": repo.get_contributors().totalCount,
            },
            "topics": repo.get_topics(),
        }

        return metadata
+ map_mul_no_files=8, + ): + self.io = io + self.verbose = verbose + + if not root: + root = os.getcwd() + self.root = root + + + + self.max_map_tokens = map_tokens + self.map_mul_no_files = map_mul_no_files + self.max_context_window = max_context_window + + + self.repo_content_prefix = repo_content_prefix + + def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None): + if self.max_map_tokens <= 0: + return + if not other_files: + return + if not mentioned_fnames: + mentioned_fnames = set() + if not mentioned_idents: + mentioned_idents = set() + + max_map_tokens = self.max_map_tokens + + # With no files in the chat, give a bigger view of the entire repo + padding = 4096 + if max_map_tokens and self.max_context_window: + target = min( + max_map_tokens * self.map_mul_no_files, + self.max_context_window - padding, + ) + else: + target = 0 + if not chat_files and self.max_context_window and target > 0: + max_map_tokens = target + + try: + files_listing = self.get_ranked_tags_map( + chat_files, other_files, max_map_tokens, mentioned_fnames, mentioned_idents + ) + except RecursionError: + self.io.tool_error("Disabling repo map, git repo too large?") + self.max_map_tokens = 0 + return + + if not files_listing: + return + + num_tokens = self.token_count(files_listing) + if self.verbose: + self.io.tool_output(f"Repo-map: {num_tokens / 1024:.1f} k-tokens") + + if chat_files: + other = "other " + else: + other = "" + + if self.repo_content_prefix: + repo_content = self.repo_content_prefix.format(other=other) + else: + repo_content = "" + + repo_content += files_listing + + return repo_content + + def get_rel_fname(self, fname): + return os.path.relpath(fname, self.root) + + def split_path(self, path): + path = os.path.relpath(path, self.root) + return [path + ":"] + + + def save_tags_cache(self): + pass + + def get_mtime(self, fname): + try: + return os.path.getmtime(fname) + except FileNotFoundError: + self.io.tool_error(f"File not found 
error: {fname}") + + def get_tags(self, fname, rel_fname): + # Check if the file is in the cache and if the modification time has not changed + file_mtime = self.get_mtime(fname) + if file_mtime is None: + return [] + + data = list(self.get_tags_raw(fname, rel_fname)) + + + return data + + def get_tags_raw(self, fname, rel_fname): + lang = filename_to_lang(fname) + if not lang: + return + + language = get_language(lang) + parser = get_parser(lang) + + query_scm = get_scm_fname(lang) + if not query_scm.exists(): + return + query_scm = query_scm.read_text() + + code = self.io.read_text(fname) + if not code: + return + tree = parser.parse(bytes(code, "utf-8")) + + # Run the tags queries + query = language.query(query_scm) + captures = query.captures(tree.root_node) + + captures = list(captures) + + saw = set() + for node, tag in captures: + if tag.startswith("name.definition."): + kind = "def" + type = tag.split(".")[-1] # + elif tag.startswith("name.reference."): + kind = "ref" + type = tag.split(".")[-1] # + else: + continue + + saw.add(kind) + + result = Tag( + rel_fname=rel_fname, + fname=fname, + name=node.text.decode("utf-8"), + kind=kind, + line=node.start_point[0], + end_line=node.end_point[0], + type=type, + ) + + yield result + + if "ref" in saw: + return + if "def" not in saw: + return + + # We saw defs, without any refs + # Some tags files only provide defs (cpp, for example) + # Use pygments to backfill refs + + try: + lexer = guess_lexer_for_filename(fname, code) + except ClassNotFound: + return + + tokens = list(lexer.get_tokens(code)) + tokens = [token[1] for token in tokens if token[0] in Token.Name] + + for token in tokens: + yield Tag( + rel_fname=rel_fname, + fname=fname, + name=token, + kind="ref", + line=-1, + end_line=-1, + type="unknown", + ) + + def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): + defines = defaultdict(set) + references = defaultdict(list) + definitions = defaultdict(set) + + 
personalization = dict() + + fnames = set(chat_fnames).union(set(other_fnames)) + chat_rel_fnames = set() + + fnames = sorted(fnames) + + # Default personalization for unspecified files is 1/num_nodes + # https://networkx.org/documentation/stable/_modules/networkx/algorithms/link_analysis/pagerank_alg.html#pagerank + personalize = 100 / len(fnames) + + + fnames = tqdm(fnames) + + + for fname in fnames: + if not Path(fname).is_file(): + if fname not in self.warned_files: + if Path(fname).exists(): + self.io.tool_error( + f"Repo-map can't include {fname}, it is not a normal file" + ) + else: + self.io.tool_error(f"Repo-map can't include {fname}, it no longer exists") + + self.warned_files.add(fname) + continue + + # dump(fname) + rel_fname = self.get_rel_fname(fname) + + if fname in chat_fnames: + personalization[rel_fname] = personalize + chat_rel_fnames.add(rel_fname) + + if rel_fname in mentioned_fnames: + personalization[rel_fname] = personalize + + tags = list(self.get_tags(fname, rel_fname)) + if tags is None: + continue + + for tag in tags: + if tag.kind == "def": + defines[tag.name].add(rel_fname) + key = (rel_fname, tag.name) + definitions[key].add(tag) + + if tag.kind == "ref": + references[tag.name].append(rel_fname) + + ## + # dump(defines) + # dump(references) + # dump(personalization) + + if not references: + references = dict((k, list(v)) for k, v in defines.items()) + + idents = set(defines.keys()).intersection(set(references.keys())) + + G = nx.MultiDiGraph() + + for ident in idents: + definers = defines[ident] + if ident in mentioned_idents: + mul = 10 + elif ident.startswith("_"): + mul = 0.1 + else: + mul = 1 + + for referencer, num_refs in Counter(references[ident]).items(): + for definer in definers: + # dump(referencer, definer, num_refs, mul) + # if referencer == definer: + # continue + + # scale down so high freq (low value) mentions don't dominate + num_refs = math.sqrt(num_refs) + + G.add_edge(referencer, definer, weight=mul * num_refs, 
ident=ident) + + if not references: + pass + + if personalization: + pers_args = dict(personalization=personalization, dangling=personalization) + else: + pers_args = dict() + + try: + ranked = nx.pagerank(G, weight="weight", **pers_args) + except ZeroDivisionError: + return [] + + # distribute the rank from each source node, across all of its out edges + ranked_definitions = defaultdict(float) + for src in G.nodes: + src_rank = ranked[src] + total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True)) + # dump(src, src_rank, total_weight) + for _src, dst, data in G.out_edges(src, data=True): + data["rank"] = src_rank * data["weight"] / total_weight + ident = data["ident"] + ranked_definitions[(dst, ident)] += data["rank"] + + ranked_tags = [] + ranked_definitions = sorted(ranked_definitions.items(), reverse=True, key=lambda x: x[1]) + + # dump(ranked_definitions) + + for (fname, ident), rank in ranked_definitions: + + if fname in chat_rel_fnames: + continue + ranked_tags += list(definitions.get((fname, ident), [])) + + rel_other_fnames_without_tags = set(self.get_rel_fname(fname) for fname in other_fnames) + + fnames_already_included = set(rt[0] for rt in ranked_tags) + + top_rank = sorted([(rank, node) for (node, rank) in ranked.items()], reverse=True) + for rank, fname in top_rank: + if fname in rel_other_fnames_without_tags: + rel_other_fnames_without_tags.remove(fname) + if fname not in fnames_already_included: + ranked_tags.append((fname,)) + + for fname in rel_other_fnames_without_tags: + ranked_tags.append((fname,)) + + return ranked_tags + + def get_ranked_tags_map( + self, + chat_fnames, + other_fnames=None, + max_map_tokens=None, + mentioned_fnames=None, + mentioned_idents=None, + ): + if not other_fnames: + other_fnames = list() + if not max_map_tokens: + max_map_tokens = self.max_map_tokens + if not mentioned_fnames: + mentioned_fnames = set() + if not mentioned_idents: + mentioned_idents = set() + + ranked_tags = 
self.get_ranked_tags( + chat_fnames, other_fnames, mentioned_fnames, mentioned_idents + ) + + num_tags = len(ranked_tags) + lower_bound = 0 + upper_bound = num_tags + best_tree = None + best_tree_tokens = 0 + + chat_rel_fnames = [self.get_rel_fname(fname) for fname in chat_fnames] + + # Guess a small starting number to help with giant repos + middle = min(max_map_tokens // 25, num_tags) + + self.tree_cache = dict() + + while lower_bound <= upper_bound: + tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames) + num_tokens = self.token_count(tree) + + if num_tokens < max_map_tokens and num_tokens > best_tree_tokens: + best_tree = tree + best_tree_tokens = num_tokens + + if num_tokens < max_map_tokens: + lower_bound = middle + 1 + else: + upper_bound = middle - 1 + + middle = (lower_bound + upper_bound) // 2 + + return best_tree + + tree_cache = dict() + + def render_tree(self, abs_fname, rel_fname, lois): + key = (rel_fname, tuple(sorted(lois))) + + if key in self.tree_cache: + return self.tree_cache[key] + + code = self.io.read_text(abs_fname) or "" + if not code.endswith("\n"): + code += "\n" + + context = TreeContext( + rel_fname, + code, + color=False, + line_number=False, + child_context=False, + last_line=False, + margin=0, + mark_lois=False, + loi_pad=0, + show_top_of_file_parent_scope=False, + ) + + for start, end in lois: + context.add_lines_of_interest(range(start, end + 1)) + context.add_context() + res = context.format() + self.tree_cache[key] = res + return res + + + def create_graph(self, repo_dir): + + start_time = time.time() + logging.info("Starting parsing of codebase") # Log start + + G = nx.MultiDiGraph() + defines = defaultdict(list) + references = defaultdict(list) + file_count = 0 # Initialize file counter + for root, _, files in os.walk(repo_dir): + for file in files: + file_count += 1 # Increment file counter + logging.info(f"Processing file number: {file_count}") # Log file number + + file_path = os.path.join(root, file) + rel_path = 
os.path.relpath(file_path, repo_dir) + + if not self.is_text_file(file_path): + continue + + tags = self.get_tags(file_path, rel_path) + + current_class = None + current_function = None + for tag in tags: + if tag.kind == "def": + if tag.type == "class": + current_class = tag.name + current_function = None + elif tag.type == "function": + current_function = tag.name + node_name = f"{current_class}.{tag.name}@{rel_path}" if current_class else f"{rel_path}:{tag.name}" + defines[tag.name].append((node_name, tag.line, tag.end_line, tag.type, rel_path, current_class)) + G.add_node(node_name, file=rel_path, line=tag.line, end_line=tag.end_line, type=tag.type) + elif tag.kind == "ref": + source = f"{current_class}.{current_function}@{rel_path}" if current_class and current_function else f"{rel_path}:{current_function}" if current_function else rel_path + references[tag.name].append((source, tag.line, tag.end_line, tag.type, rel_path, current_class)) + + # Create edges + for ident, refs in references.items(): + if ident in defines: + if len(defines[ident]) == 1: # Unique definition + target, def_line, end_def_line, def_type, def_file, def_class = defines[ident][0] + for (source, ref_line, end_ref_line, ref_type, ref_file, ref_class) in refs: + G.add_edge(source, target, type=ref_type, ident=ident, ref_line=ref_line, end_ref_line=end_ref_line, def_line=def_line, end_def_line=end_def_line) + else: # Apply scoring system for non-unique definitions + for (source, ref_line, end_ref_line, ref_type, ref_file, ref_class) in refs: + best_match = None + best_match_score = -1 + for (target, def_line, end_def_line, def_type, def_file, def_class) in defines[ident]: + if source != target: # Avoid self-references + match_score = 0 + if ref_file == def_file: + match_score += 2 + elif os.path.dirname(ref_file) == os.path.dirname(def_file): + match_score += 1 # Add a point for being in the same directory + if ref_class == def_class: + match_score += 1 + if match_score > best_match_score: + 
best_match = (target, def_line, end_def_line, def_type) + best_match_score = match_score + + if best_match: + target, def_line, end_def_line, def_type = best_match + G.add_edge(source, target, type=ref_type, ident=ident, ref_line=ref_line, end_ref_line=end_ref_line, def_line=def_line, end_def_line=end_def_line) + + end_time = time.time() + logging.info(f"Parsing completed, time taken: {end_time - start_time} seconds") # Log end + return G + + def is_text_file(self, file_path): + # Simple check to determine if a file is likely to be a text file + # You might want to expand this based on your specific needs + try: + with open(file_path, 'r', encoding='utf-8') as f: + f.read(1024) + return True + except UnicodeDecodeError: + return False + def to_tree(self, tags, chat_rel_fnames): + if not tags: + return "" + + tags = [tag for tag in tags if tag[0] not in chat_rel_fnames] + tags = sorted(tags) + + cur_fname = None + cur_abs_fname = None + lois = None + output = "" + + # add a bogus tag at the end so we trip the this_fname != cur_fname... + dummy_tag = (None,) + for tag in tags + [dummy_tag]: + this_rel_fname = tag[0] + + # ... here ... 
to output the final real entry in the list + if this_rel_fname != cur_fname: + if lois is not None: + output += "\n" + output += cur_fname + ":\n" + output += self.render_tree(cur_abs_fname, cur_fname, lois) + lois = None + elif cur_fname: + output += "\n" + cur_fname + "\n" + if type(tag) is Tag: + lois = [] + cur_abs_fname = tag.fname + cur_fname = this_rel_fname + + if lois is not None: + lois.append((tag.line, tag.end_line)) + + # truncate long lines, in case we get minified js or something else crazy + output = "\n".join([line[:100] for line in output.splitlines()]) + "\n" + + return output + + + + + +def get_scm_fname(lang): + # Load the tags queries + try: + return Path(os.path.dirname(__file__)).joinpath("queries", f"tree-sitter-{lang}-tags.scm") + except KeyError: + return + + diff --git a/app/modules/parsing/graph_construction/parsing_router.py b/app/modules/parsing/graph_construction/parsing_router.py new file mode 100644 index 00000000..a69990e1 --- /dev/null +++ b/app/modules/parsing/graph_construction/parsing_router.py @@ -0,0 +1,119 @@ +import json +from typing import Any, List +from fastapi import Depends, HTTPException, Query +from app.modules.utils.APIRouter import APIRouter +from fastapi.responses import StreamingResponse +from app.modules.auth.auth_service import AuthService +from sqlalchemy.orm import Session +from app.core.database import get_db +import asyncio +from app.modules.projects.projects_schema import ProjectStatusEnum +from .parsing_schema import ( + ParsingRequest, + RepoDetails + ) +from app.modules.projects.projects_service import ProjectService +import os +import shutil +import logging +import traceback +from contextlib import contextmanager +from typing import Dict, Tuple +from app.modules.github.github_service import GithubService +from github import Github +from app.modules.parsing.graph_construction.parsing_helper import ParseHelper, ParsingServiceError +from app.modules.parsing.graph_construction.parsing_service import 
import asyncio
import json
import logging
import os
import shutil
import traceback
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple

from fastapi import Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from git import GitCommandError, Repo
from github import Github
from sqlalchemy.orm import Session
from uuid6 import uuid7

from app.core.database import get_db
from app.modules.auth.auth_service import AuthService
from app.modules.github.github_service import GithubService
from app.modules.parsing.graph_construction.parsing_helper import ParseHelper, ParsingServiceError
from app.modules.parsing.graph_construction.parsing_service import ParsingService
from app.modules.projects.projects_schema import ProjectStatusEnum
from app.modules.projects.projects_service import ProjectService
from app.modules.utils.APIRouter import APIRouter

from .parsing_schema import ParsingRequest, RepoDetails

router = APIRouter()


class ParsingAPI:
    """HTTP endpoints for triggering repository parsing."""

    @staticmethod
    @contextmanager
    def change_dir(path):
        """Temporarily chdir into *path*, restoring the previous cwd on exit."""
        old_dir = os.getcwd()
        os.chdir(path)
        try:
            yield
        finally:
            os.chdir(old_dir)

    @staticmethod
    async def clone_or_copy_repository(
        repo_details: RepoDetails, db: Session, user_id: str
    ) -> Tuple[Any, str, Any]:
        """Resolve a local working copy or a GitHub repository.

        Returns (repo, owner, auth); owner and auth are None for local paths.
        Raises HTTPException(400) when the repository cannot be resolved.
        """
        if repo_details.repo_path:
            if not os.path.exists(repo_details.repo_path):
                raise HTTPException(
                    status_code=400,
                    detail="Local repository does not exist on given path",
                )
            repo = Repo(repo_details.repo_path)
            owner = None
            auth = None
        else:
            github_service = GithubService(db)
            response, auth, owner = github_service.get_github_repo_details(
                repo_details.repo_name
            )
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail="Failed to get installation ID")
            app_auth = auth.get_installation_auth(response.json()["id"])
            github = Github(auth=app_auth)
            try:
                repo = github.get_repo(repo_details.repo_name)
            except Exception as e:
                # Chain the original error for debuggability.
                raise HTTPException(
                    status_code=400, detail="Repository not found on GitHub"
                ) from e

        return repo, owner, auth

    @router.post("/parse")
    async def parse_directory(
        repo_details: ParsingRequest,
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        """Parse a repository and build its knowledge graph, updating the
        project status at every exit path."""
        user_id = user["user_id"]
        project_manager = ProjectService(db)
        project_id = None
        parse_helper = ParseHelper(db)
        project = await project_manager.get_project_from_db(repo_details.repo_name, user_id)
        extracted_dir = None
        if project:
            project_id = project.id

        try:
            # Step 1: validate input, then materialise the repo on disk.
            ParsingAPI.validate_input(repo_details, user_id)
            repo, owner, auth = await ParsingAPI.clone_or_copy_repository(
                repo_details, db, user_id
            )

            extracted_dir, project_id = await parse_helper.setup_project_directory(
                repo, repo_details.branch_name, auth, repo, user_id, project_id
            )

            await ParsingService.analyze_directory(extracted_dir, project_id, user_id, db)

            message = "The project has been parsed successfully"
            await project_manager.update_project_status(project_id, ProjectStatusEnum.READY)
            return {"message": message, "id": project_id}

        except ParsingServiceError as e:
            # str(e), not e.message: BaseException.message does not exist in
            # Python 3 and reading it raised AttributeError here.
            message = f"{project_id} Failed during parsing: {e}"
            await project_manager.update_project_status(project_id, ProjectStatusEnum.ERROR)
            raise HTTPException(status_code=500, detail=message)
        except HTTPException as http_ex:
            if project_id:
                await project_manager.update_project_status(project_id, ProjectStatusEnum.ERROR)
            raise http_ex
        except Exception as e:
            if project_id:
                await project_manager.update_project_status(project_id, ProjectStatusEnum.ERROR)
            tb_str = "".join(traceback.format_exception(None, e, e.__traceback__))
            raise HTTPException(status_code=500, detail=f"{str(e)}\nTraceback: {tb_str}")
        finally:
            # Single cleanup point for every exit path (the success path
            # previously also removed the directory inline; once is enough).
            if extracted_dir:
                shutil.rmtree(extracted_dir, ignore_errors=True)

    @staticmethod
    def validate_input(repo_details: ParsingRequest, user_id: str):
        """Reject local-path parsing outside development mode and remote
        parsing by the anonymous default user."""
        if os.getenv("isDevelopmentMode") != "enabled" and repo_details.repo_path:
            raise HTTPException(status_code=403, detail="Development mode is not enabled, cannot parse local repository.")
        if user_id == os.getenv("defaultUsername") and repo_details.repo_name:
            raise HTTPException(status_code=403, detail="Cannot parse remote repository without auth token")


from typing import Optional

from pydantic import BaseModel, Field


class ParsingRequest(BaseModel):
    """Request body for /parse: requires repo_name (remote) or repo_path
    (local) alongside a branch name."""

    repo_name: Optional[str] = Field(default=None)
    repo_path: Optional[str] = Field(default=None)

    branch_name: str

    def __init__(self, **data):
        super().__init__(**data)
        # At least one repository source must be supplied.
        if not self.repo_name and not self.repo_path:
            raise ValueError('Either repo_name or repo_path must be provided.')
class ParsingResponse(BaseModel):
    """Response envelope for parsing endpoints."""

    message: str
    status: str
    project_id: str


class RepoDetails(BaseModel):
    """Identifies a remote repository and branch."""

    repo_name: str
    branch_name: str


import hashlib
import logging
import traceback

from blar_graph.db_managers import Neo4jManager
from blar_graph.graph_construction.core.graph_builder import GraphConstructor
from neo4j import GraphDatabase

from app.core.config import config_provider
from app.modules.parsing.graph_construction.parsing_helper import ParseHelper, ParsingFailedError, RepoMap
from app.modules.projects.projects_schema import ProjectStatusEnum
from app.modules.projects.projects_service import ProjectService


class SimpleIO:
    """Minimal IO adapter satisfying RepoMap's io interface."""

    def read_text(self, fname):
        # Repos are assumed UTF-8; the original relied on the platform
        # default encoding, which broke on non-UTF-8 locales.
        with open(fname, "r", encoding="utf-8") as f:
            return f.read()

    def tool_error(self, message):
        logging.error(f"Error: {message}")

    def tool_output(self, message):
        logging.info(message)


class SimpleTokenCounter:
    """Whitespace-token counter used to budget the repo map."""

    def token_count(self, text):
        return len(text.split())


class CodeGraphService:
    """Persists the RepoMap reference graph into Neo4j."""

    def __init__(self, neo4j_uri, neo4j_user, neo4j_password):
        self.driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))

    @staticmethod
    def generate_node_id(path: str, user_id: str):
        """Deterministic node id: MD5 hex digest of "user_id:path".

        Note: the old comment claimed SHA-1, but the code has always used
        MD5; that is fine here because the hash is an identifier, not a
        security control.
        """
        combined_string = f"{user_id}:{path}"
        hash_object = hashlib.md5()
        hash_object.update(combined_string.encode("utf-8"))
        return hash_object.hexdigest()
+ def close(self): + self.driver.close() + + def create_and_store_graph(self, repo_dir, project_id, user_id): + # Create the graph using RepoMap + self.repo_map = RepoMap( + root=repo_dir, + verbose=True, + main_model=SimpleTokenCounter(), + io=SimpleIO(), + ) + + nx_graph = self.repo_map.create_graph(repo_dir) + + with self.driver.session() as session: + + # Create nodes + import time + + start_time = time.time() # Start timing + node_count = nx_graph.number_of_nodes() + logging.info(f"Creating {node_count} nodes") + + # Batch insert nodes + batch_size = 300 + for i in range(0, node_count, batch_size): + batch_nodes = list(nx_graph.nodes(data=True))[i:i + batch_size] + session.run( + "UNWIND $nodes AS node " + "CREATE (d:Definition {name: node.name, file: node.file, start_line: node.line, repoId: node.repoId, node_id: node.node_id, entityId: node.entityId})", + nodes=[{'name': node[0], 'file': node[1].get('file', ''), 'start_line': node[1].get('line', -1), + 'repoId': project_id, 'node_id': CodeGraphService.generate_node_id(node[1].get('file', ''), user_id), 'entityId': user_id} for node in batch_nodes] + ) + + relationship_count = nx_graph.number_of_edges() + logging.info(f"Creating {relationship_count} relationships") + + # Create relationships in batches + for i in range(0, relationship_count, batch_size): + batch_edges = list(nx_graph.edges(data=True))[i:i + batch_size] + session.run( + """ + UNWIND $edges AS edge + MATCH (s:Definition {name: edge.source}), (t:Definition {name: edge.target}) + CREATE (s)-[:REFERENCES {type: edge.type}]->(t) + """, + edges=[{'source': edge[0], 'target': edge[1], 'type': edge[2]['type']} for edge in batch_edges] + ) + + end_time = time.time() # End timing + logging.info(f"Time taken to create graph: {end_time - start_time:.2f} seconds") # Log time taken + + + def query_graph(self, query): + with self.driver.session() as session: + result = session.run(query) + return [record.data() for record in result] + + + +class 
ParsingService: + + + async def analyze_directory(repo_dir, project_id, user_id, db): + + + repo_lang = ParseHelper(db).detect_repo_language(repo_dir) + + if repo_lang in [ "python", "javascript", "typescript"]: + + graph_manager = Neo4jManager(project_id, user_id) + + try: + graph_constructor = GraphConstructor(graph_manager, user_id) + n, r = graph_constructor.build_graph(repo_dir) + graph_manager.save_graph(n, r) + await ProjectService(db).update_project_status(project_id, ProjectStatusEnum.PARSED) + except Exception as e: + logging.error(e) + logging.error(traceback.format_exc()) + + finally: + graph_manager.close() + elif repo_lang != "other": + try: + neo4j_config = config_provider.get_neo4j_config() + service = CodeGraphService( + neo4j_config["uri"], + neo4j_config["username"], + neo4j_config["password"] + ) + + service.create_and_store_graph(repo_dir, project_id, user_id) + await ProjectService(db).update_project_status(project_id, ProjectStatusEnum.PARSED) + + finally: + service.close() + else: + await ProjectService(db).update_project_status(project_id, ProjectStatusEnum.ERROR) + return ParsingFailedError("Repository doesn't consist of a language currently supported.") + + \ No newline at end of file diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-c-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-c-tags.scm new file mode 100644 index 00000000..1035aa22 --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-c-tags.scm @@ -0,0 +1,9 @@ +(struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class + +(declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class + +(function_declarator declarator: (identifier) @name.definition.function) @definition.function + +(type_definition declarator: (type_identifier) @name.definition.type) @definition.type + +(enum_specifier name: (type_identifier) @name.definition.type) 
@definition.type diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-c_sharp-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-c_sharp-tags.scm new file mode 100644 index 00000000..58e9199a --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-c_sharp-tags.scm @@ -0,0 +1,46 @@ +(class_declaration + name: (identifier) @name.definition.class + ) @definition.class + +(class_declaration + bases: (base_list (_) @name.reference.class) + ) @reference.class + +(interface_declaration + name: (identifier) @name.definition.interface + ) @definition.interface + +(interface_declaration + bases: (base_list (_) @name.reference.interface) + ) @reference.interface + +(method_declaration + name: (identifier) @name.definition.method + ) @definition.method + +(object_creation_expression + type: (identifier) @name.reference.class + ) @reference.class + +(type_parameter_constraints_clause + target: (identifier) @name.reference.class + ) @reference.class + +(type_constraint + type: (identifier) @name.reference.class + ) @reference.class + +(variable_declaration + type: (identifier) @name.reference.class + ) @reference.class + +(invocation_expression + function: + (member_access_expression + name: (identifier) @name.reference.send + ) +) @reference.send + +(namespace_declaration + name: (identifier) @name.definition.module +) @definition.module diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-cpp-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-cpp-tags.scm new file mode 100644 index 00000000..7a7ad0b9 --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-cpp-tags.scm @@ -0,0 +1,15 @@ +(struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class + +(declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class + +(function_declarator declarator: (identifier) @name.definition.function) 
@definition.function + +(function_declarator declarator: (field_identifier) @name.definition.function) @definition.function + +(function_declarator declarator: (qualified_identifier scope: (namespace_identifier) @scope name: (identifier) @name.definition.method)) @definition.method + +(type_definition declarator: (type_identifier) @name.definition.type) @definition.type + +(enum_specifier name: (type_identifier) @name.definition.type) @definition.type + +(class_specifier name: (type_identifier) @name.definition.class) @definition.class diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-elisp-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-elisp-tags.scm new file mode 100644 index 00000000..743c8d8a --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-elisp-tags.scm @@ -0,0 +1,8 @@ +;; defun/defsubst +(function_definition name: (symbol) @name.definition.function) @definition.function + +;; Treat macros as function definitions for the sake of TAGS. +(macro_definition name: (symbol) @name.definition.function) @definition.function + +;; Match function calls +(list (symbol) @name.reference.function) @reference.function diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-elixir-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-elixir-tags.scm new file mode 100644 index 00000000..9eb39d95 --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-elixir-tags.scm @@ -0,0 +1,54 @@ +; Definitions + +; * modules and protocols +(call + target: (identifier) @ignore + (arguments (alias) @name.definition.module) + (#match? 
@ignore "^(defmodule|defprotocol)$")) @definition.module + +; * functions/macros +(call + target: (identifier) @ignore + (arguments + [ + ; zero-arity functions with no parentheses + (identifier) @name.definition.function + ; regular function clause + (call target: (identifier) @name.definition.function) + ; function clause with a guard clause + (binary_operator + left: (call target: (identifier) @name.definition.function) + operator: "when") + ]) + (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @definition.function + +; References + +; ignore calls to kernel/special-forms keywords +(call + target: (identifier) @ignore + (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp|defmodule|defprotocol|defimpl|defstruct|defexception|defoverridable|alias|case|cond|else|for|if|import|quote|raise|receive|require|reraise|super|throw|try|unless|unquote|unquote_splicing|use|with)$")) + +; ignore module attributes +(unary_operator + operator: "@" + operand: (call + target: (identifier) @ignore)) + +; * function call +(call + target: [ + ; local + (identifier) @name.reference.call + ; remote + (dot + right: (identifier) @name.reference.call) + ]) @reference.call + +; * pipe into function call +(binary_operator + operator: "|>" + right: (identifier) @name.reference.call) @reference.call + +; * modules +(alias) @name.reference.module @reference.module diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-elm-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-elm-tags.scm new file mode 100644 index 00000000..8b1589e9 --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-elm-tags.scm @@ -0,0 +1,19 @@ +(value_declaration (function_declaration_left (lower_case_identifier) @name.definition.function)) @definition.function + +(function_call_expr (value_expr (value_qid) @name.reference.function)) @reference.function +(exposed_value 
(lower_case_identifier) @name.reference.function) @reference.function +(type_annotation ((lower_case_identifier) @name.reference.function) (colon)) @reference.function + +(type_declaration ((upper_case_identifier) @name.definition.type)) @definition.type + +(type_ref (upper_case_qid (upper_case_identifier) @name.reference.type)) @reference.type +(exposed_type (upper_case_identifier) @name.reference.type) @reference.type + +(type_declaration (union_variant (upper_case_identifier) @name.definition.union)) @definition.union + +(value_expr (upper_case_qid (upper_case_identifier) @name.reference.union)) @reference.union + + +(module_declaration + (upper_case_qid (upper_case_identifier)) @name.definition.module +) @definition.module diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-go-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-go-tags.scm new file mode 100644 index 00000000..a32d03aa --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-go-tags.scm @@ -0,0 +1,30 @@ +( + (comment)* @doc + . + (function_declaration + name: (identifier) @name.definition.function) @definition.function + (#strip! @doc "^//\\s*") + (#set-adjacent! @doc @definition.function) +) + +( + (comment)* @doc + . + (method_declaration + name: (field_identifier) @name.definition.method) @definition.method + (#strip! @doc "^//\\s*") + (#set-adjacent! 
@doc @definition.method) +) + +(call_expression + function: [ + (identifier) @name.reference.call + (parenthesized_expression (identifier) @name.reference.call) + (selector_expression field: (field_identifier) @name.reference.call) + (parenthesized_expression (selector_expression field: (field_identifier) @name.reference.call)) + ]) @reference.call + +(type_spec + name: (type_identifier) @name.definition.type) @definition.type + +(type_identifier) @name.reference.type @reference.type diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-java-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-java-tags.scm new file mode 100644 index 00000000..3b7290d4 --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-java-tags.scm @@ -0,0 +1,20 @@ +(class_declaration + name: (identifier) @name.definition.class) @definition.class + +(method_declaration + name: (identifier) @name.definition.method) @definition.method + +(method_invocation + name: (identifier) @name.reference.call + arguments: (argument_list) @reference.call) + +(interface_declaration + name: (identifier) @name.definition.interface) @definition.interface + +(type_list + (type_identifier) @name.reference.implementation) @reference.implementation + +(object_creation_expression + type: (type_identifier) @name.reference.class) @reference.class + +(superclass (type_identifier) @name.reference.class) @reference.class diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-javascript-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-javascript-tags.scm new file mode 100644 index 00000000..3bc55c5c --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-javascript-tags.scm @@ -0,0 +1,88 @@ +( + (comment)* @doc + . + (method_definition + name: (property_identifier) @name.definition.method) @definition.method + (#not-eq? @name.definition.method "constructor") + (#strip! 
@doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.method) +) + +( + (comment)* @doc + . + [ + (class + name: (_) @name.definition.class) + (class_declaration + name: (_) @name.definition.class) + ] @definition.class + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.class) +) + +( + (comment)* @doc + . + [ + (function + name: (identifier) @name.definition.function) + (function_declaration + name: (identifier) @name.definition.function) + (generator_function + name: (identifier) @name.definition.function) + (generator_function_declaration + name: (identifier) @name.definition.function) + ] @definition.function + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.function) +) + +( + (comment)* @doc + . + (lexical_declaration + (variable_declarator + name: (identifier) @name.definition.function + value: [(arrow_function) (function)]) @definition.function) + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.function) +) + +( + (comment)* @doc + . + (variable_declaration + (variable_declarator + name: (identifier) @name.definition.function + value: [(arrow_function) (function)]) @definition.function) + (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") + (#select-adjacent! @doc @definition.function) +) + +(assignment_expression + left: [ + (identifier) @name.definition.function + (member_expression + property: (property_identifier) @name.definition.function) + ] + right: [(arrow_function) (function)] +) @definition.function + +(pair + key: (property_identifier) @name.definition.function + value: [(arrow_function) (function)]) @definition.function + +( + (call_expression + function: (identifier) @name.reference.call) @reference.call + (#not-match? 
@name.reference.call "^(require)$") +) + +(call_expression + function: (member_expression + property: (property_identifier) @name.reference.call) + arguments: (_) @reference.call) + +(new_expression + constructor: (_) @name.reference.class) @reference.class diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-ocaml-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-ocaml-tags.scm new file mode 100644 index 00000000..52d5a857 --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-ocaml-tags.scm @@ -0,0 +1,115 @@ +; Modules +;-------- + +( + (comment)? @doc . + (module_definition (module_binding (module_name) @name.definition.module) @definition.module) + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(module_path (module_name) @name.reference.module) @reference.module + +; Module types +;-------------- + +( + (comment)? @doc . + (module_type_definition (module_type_name) @name.definition.interface) @definition.interface + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(module_type_path (module_type_name) @name.reference.implementation) @reference.implementation + +; Functions +;---------- + +( + (comment)? @doc . + (value_definition + [ + (let_binding + pattern: (value_name) @name.definition.function + (parameter)) + (let_binding + pattern: (value_name) @name.definition.function + body: [(fun_expression) (function_expression)]) + ] @definition.function + ) + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +( + (comment)? @doc . + (external (value_name) @name.definition.function) @definition.function + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(application_expression + function: (value_path (value_name) @name.reference.call)) @reference.call + +(infix_expression + left: (value_path (value_name) @name.reference.call) + operator: (concat_operator) @reference.call + (#eq? 
@reference.call "@@")) + +(infix_expression + operator: (rel_operator) @reference.call + right: (value_path (value_name) @name.reference.call) + (#eq? @reference.call "|>")) + +; Operator +;--------- + +( + (comment)? @doc . + (value_definition + (let_binding + pattern: (parenthesized_operator (_) @name.definition.function)) @definition.function) + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +[ + (prefix_operator) + (sign_operator) + (pow_operator) + (mult_operator) + (add_operator) + (concat_operator) + (rel_operator) + (and_operator) + (or_operator) + (assign_operator) + (hash_operator) + (indexing_operator) + (let_operator) + (let_and_operator) + (match_operator) +] @name.reference.call @reference.call + +; Classes +;-------- + +( + (comment)? @doc . + [ + (class_definition (class_binding (class_name) @name.definition.class) @definition.class) + (class_type_definition (class_type_binding (class_type_name) @name.definition.class) @definition.class) + ] + (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +[ + (class_path (class_name) @name.reference.class) + (class_type_path (class_type_name) @name.reference.class) +] @reference.class + +; Methods +;-------- + +( + (comment)? @doc . + (method_definition (method_name) @name.definition.method) @definition.method + (#strip! 
@doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") +) + +(method_invocation (method_name) @name.reference.call) @reference.call diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-php-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-php-tags.scm new file mode 100644 index 00000000..61c86fcb --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-php-tags.scm @@ -0,0 +1,26 @@ +(class_declaration + name: (name) @name.definition.class) @definition.class + +(function_definition + name: (name) @name.definition.function) @definition.function + +(method_declaration + name: (name) @name.definition.function) @definition.function + +(object_creation_expression + [ + (qualified_name (name) @name.reference.class) + (variable_name (name) @name.reference.class) + ]) @reference.class + +(function_call_expression + function: [ + (qualified_name (name) @name.reference.call) + (variable_name (name)) @name.reference.call + ]) @reference.call + +(scoped_call_expression + name: (name) @name.reference.call) @reference.call + +(member_call_expression + name: (name) @name.reference.call) @reference.call diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-python-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-python-tags.scm new file mode 100644 index 00000000..3be5bed9 --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-python-tags.scm @@ -0,0 +1,12 @@ +(class_definition + name: (identifier) @name.definition.class) @definition.class + +(function_definition + name: (identifier) @name.definition.function) @definition.function + +(call + function: [ + (identifier) @name.reference.call + (attribute + attribute: (identifier) @name.reference.call) + ]) @reference.call diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-ql-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-ql-tags.scm new file mode 100644 index 00000000..3164aa25 --- 
/dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-ql-tags.scm @@ -0,0 +1,26 @@ +(classlessPredicate + name: (predicateName) @name.definition.function) @definition.function + +(memberPredicate + name: (predicateName) @name.definition.method) @definition.method + +(aritylessPredicateExpr + name: (literalId) @name.reference.call) @reference.call + +(module + name: (moduleName) @name.definition.module) @definition.module + +(dataclass + name: (className) @name.definition.class) @definition.class + +(datatype + name: (className) @name.definition.class) @definition.class + +(datatypeBranch + name: (className) @name.definition.class) @definition.class + +(qualifiedRhs + name: (predicateName) @name.reference.call) @reference.call + +(typeExpr + name: (className) @name.reference.type) @reference.type diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-ruby-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-ruby-tags.scm new file mode 100644 index 00000000..79e71d2d --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-ruby-tags.scm @@ -0,0 +1,64 @@ +; Method definitions + +( + (comment)* @doc + . + [ + (method + name: (_) @name.definition.method) @definition.method + (singleton_method + name: (_) @name.definition.method) @definition.method + ] + (#strip! @doc "^#\\s*") + (#select-adjacent! @doc @definition.method) +) + +(alias + name: (_) @name.definition.method) @definition.method + +(setter + (identifier) @ignore) + +; Class definitions + +( + (comment)* @doc + . + [ + (class + name: [ + (constant) @name.definition.class + (scope_resolution + name: (_) @name.definition.class) + ]) @definition.class + (singleton_class + value: [ + (constant) @name.definition.class + (scope_resolution + name: (_) @name.definition.class) + ]) @definition.class + ] + (#strip! @doc "^#\\s*") + (#select-adjacent! 
@doc @definition.class) +) + +; Module definitions + +( + (module + name: [ + (constant) @name.definition.module + (scope_resolution + name: (_) @name.definition.module) + ]) @definition.module +) + +; Calls + +(call method: (identifier) @name.reference.call) @reference.call + +( + [(identifier) (constant)] @name.reference.call @reference.call + (#is-not? local) + (#not-match? @name.reference.call "^(lambda|load|require|require_relative|__FILE__|__LINE__)$") +) diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-rust-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-rust-tags.scm new file mode 100644 index 00000000..dadfa7ac --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-rust-tags.scm @@ -0,0 +1,60 @@ +; ADT definitions + +(struct_item + name: (type_identifier) @name.definition.class) @definition.class + +(enum_item + name: (type_identifier) @name.definition.class) @definition.class + +(union_item + name: (type_identifier) @name.definition.class) @definition.class + +; type aliases + +(type_item + name: (type_identifier) @name.definition.class) @definition.class + +; method definitions + +(declaration_list + (function_item + name: (identifier) @name.definition.method)) @definition.method + +; function definitions + +(function_item + name: (identifier) @name.definition.function) @definition.function + +; trait definitions +(trait_item + name: (type_identifier) @name.definition.interface) @definition.interface + +; module definitions +(mod_item + name: (identifier) @name.definition.module) @definition.module + +; macro definitions + +(macro_definition + name: (identifier) @name.definition.macro) @definition.macro + +; references + +(call_expression + function: (identifier) @name.reference.call) @reference.call + +(call_expression + function: (field_expression + field: (field_identifier) @name.reference.call)) @reference.call + +(macro_invocation + macro: (identifier) @name.reference.call) 
@reference.call + +; implementations + +(impl_item + trait: (type_identifier) @name.reference.implementation) @reference.implementation + +(impl_item + type: (type_identifier) @name.reference.implementation + !trait) @reference.implementation diff --git a/app/modules/parsing/graph_construction/queries/tree-sitter-typescript-tags.scm b/app/modules/parsing/graph_construction/queries/tree-sitter-typescript-tags.scm new file mode 100644 index 00000000..8a73dccc --- /dev/null +++ b/app/modules/parsing/graph_construction/queries/tree-sitter-typescript-tags.scm @@ -0,0 +1,41 @@ +(function_signature + name: (identifier) @name.definition.function) @definition.function + +(method_signature + name: (property_identifier) @name.definition.method) @definition.method + +(abstract_method_signature + name: (property_identifier) @name.definition.method) @definition.method + +(abstract_class_declaration + name: (type_identifier) @name.definition.class) @definition.class + +(module + name: (identifier) @name.definition.module) @definition.module + +(interface_declaration + name: (type_identifier) @name.definition.interface) @definition.interface + +(type_annotation + (type_identifier) @name.reference.type) @reference.type + +(new_expression + constructor: (identifier) @name.reference.class) @reference.class + +(function_declaration + name: (identifier) @name.definition.function) @definition.function + +(method_definition + name: (property_identifier) @name.definition.method) @definition.method + +(class_declaration + name: (type_identifier) @name.definition.class) @definition.class + +(interface_declaration + name: (type_identifier) @name.definition.class) @definition.class + +(type_alias_declaration + name: (type_identifier) @name.definition.type) @definition.type + +(enum_declaration + name: (identifier) @name.definition.enum) @definition.enum diff --git a/app/modules/projects/projects_model.py b/app/modules/projects/projects_model.py index 13be0a9e..e4c987c1 100644 --- 
a/app/modules/projects/projects_model.py +++ b/app/modules/projects/projects_model.py @@ -1,22 +1,13 @@ -from sqlalchemy import Column, ForeignKey, String, TIMESTAMP, Boolean, CheckConstraint, func, Integer, Text +from sqlalchemy import Column, ForeignKey, String, TIMESTAMP, Boolean, CheckConstraint, ForeignKeyConstraint, func, Integer, Text from sqlalchemy.dialects.postgresql import BYTEA from app.core.database import Base from sqlalchemy.orm import relationship -import enum - -class ProjectStatusEnum(str, enum.Enum): - CREATED = 'created' - READY = 'ready' - ERROR = 'error' class Project(Base): __tablename__ = "projects" - id = Column(Integer, primary_key=True) - directory = Column(Text, unique=True) - is_default = Column(Boolean, default=False) - project_name = Column(Text) + id = Column(Text, primary_key=True) properties = Column(BYTEA) repo_name = Column(Text) branch_name = Column(Text) @@ -28,7 +19,8 @@ class Project(Base): status = Column(String(255), default='created') __table_args__ = ( - CheckConstraint("status IN ('created', 'ready', 'error')", name='check_status'), + ForeignKeyConstraint(["user_id"], ["users.uid"], ondelete="CASCADE"), + CheckConstraint("status IN ('submitted', 'cloned', 'parsed', 'ready', 'error')", name='check_status'), ) # Project relationships diff --git a/app/modules/projects/projects_schema.py b/app/modules/projects/projects_schema.py new file mode 100644 index 00000000..a1d75d11 --- /dev/null +++ b/app/modules/projects/projects_schema.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class ProjectStatusEnum(str, Enum): + SUBMITTED = 'submitted' + CLONED = 'cloned' + PARSED = 'parsed' + PROCESSING = 'processing' + READY = 'ready' + ERROR = 'error' diff --git a/app/modules/projects/projects_service.py b/app/modules/projects/projects_service.py index 82b71029..b21b89a3 100644 --- a/app/modules/projects/projects_service.py +++ b/app/modules/projects/projects_service.py @@ -2,6 +2,14 @@ from sqlalchemy.orm import Session from 
sqlalchemy.exc import SQLAlchemyError from app.modules.projects.projects_model import Project +from app.modules.projects.projects_schema import ProjectStatusEnum +from app.modules.utils.model_helper import ModelHelper +import logging +from fastapi import HTTPException +from datetime import datetime +import os +import shutil +from sqlalchemy import and_ logger = logging.getLogger(__name__) @@ -31,4 +39,253 @@ async def get_project_name(self, project_ids: list) -> str: raise except Exception as e: logger.error(f"Unexpected error in get_project_name for project IDs {project_ids}: {e}", exc_info=True) - raise ProjectServiceError(f"An unexpected error occurred while retrieving project name for project IDs {project_ids}") from e \ No newline at end of file + raise ProjectServiceError(f"An unexpected error occurred while retrieving project name for project IDs {project_ids}") from e + + + async def register_project(self, repo_name: str, branch_name: str, user_id: str, project_id: str): + + + + project = Project(id=project_id, repo_name=repo_name, + branch_name=branch_name, user_id=user_id, + status=ProjectStatusEnum.SUBMITTED.value) + project = ProjectService.create_project(self.db, project) + message = f"Project id '{project.id}' for repo '{repo_name}' and branch '{branch_name}' registered successfully." 
+ logging.info(message) + return project_id + + async def list_projects(self, user_id: str): + projects = ProjectService.get_projects_by_user_id(self.db, user_id) + project_list = [] + for project in projects: + project_dict = { + "id": project.id, + "directory": project.directory, + "active": project.is_default, + } + project_list.append(project_dict) + return project_list + + async def update_project_status(self, project_id: int, status: ProjectStatusEnum): + ProjectService.update_project(self.db, project_id, status=status.value) + logging.info(f"Project with ID {project_id} has now been updated with status {status}.") + + + async def get_project_from_db(self, repo_name: str, user_id: str): + project = self.db.query(Project).filter(Project.repo_name == repo_name, Project.user_id == user_id).first() + if project: + return project + else: + return None + + async def get_project_from_db_by_id(self, project_id: int): + project = ProjectService.get_project_by_id(self.db, project_id) + if project: + return { + "project_name": project.project_name, + "directory": project.directory, + "id": project.id, + "commit_id": project.commit_id, + "status": project.status + } + else: + return None + + async def get_project_repo_details_from_db(self, project_id: int, user_id: str): + project = self.db.query(Project).filter(Project.id == project_id, Project.user_id == user_id).first() + if project: + return { + "project_name": project.project_name, + "directory": project.directory, + "id": project.id, + "repo_name": project.repo_name, + "branch_name": project.branch_name + } + else: + return None + + async def get_repo_and_branch_name(self, project_id: int): + project = ProjectService.get_project_by_id(self.db, project_id) + if project: + return project.repo_name, project.branch_name , project.directory + else: + return None + + async def get_project_from_db_by_id_and_user_id(self, project_id: int, user_id: str): + project = self.db.query(Project).filter(Project.id == project_id, 
Project.user_id == user_id).first() + if project: + return { + 'project_name': project.project_name, + 'directory': project.directory, + 'id': project.id, + 'commit_id': project.commit_id, + 'status': project.status + } + else: + return None + + async def get_first_project_from_db_by_repo_name_branch_name(self, repo_name, branch_name): + project = self.db.query(Project).filter(Project.repo_name == repo_name, Project.branch_name == branch_name).first() + if project: + return ModelHelper.model_to_dict(project) + else: + return None + + async def get_first_user_id_from_project_repo_name(self, repo_name): + project = self.db.query(Project).filter(Project.repo_name == repo_name).first() + if project: + return project.user_id + else: + return None + + async def get_parsed_project_branches(self, repo_name: str = None, user_id: str = None, default: bool = None): + query = self.db.query(Project).filter(Project.user_id == user_id) + if default is not None: + query = query.filter(Project.is_default == default) + if repo_name is not None: + query = query.filter(Project.repo_name == repo_name) + projects = query.all() + return [(p.id, p.branch_name, p.repo_name, p.updated_at, p.is_default, p.status) for p in projects] + + + async def delete_project(self, project_id: int, user_id: str): + try: + result = ProjectService.update_project( + self.db, + project_id, + is_deleted=True, + updated_at=datetime.utcnow(), + user_id=user_id + ) + if not result: + raise HTTPException( + status_code=404, + detail="No matching project found or project is already deleted." 
+ ) + else: + is_local_repo = os.getenv("isDevelopmentMode") == "enabled" and user_id == os.getenv("defaultUsername") + if is_local_repo: + project_path = self.get_project_repo_details_from_db(project_id,user_id)['directory'] + if os.path.exists(project_path): + shutil.rmtree(project_path) + logging.info(f"Deleted local project folder: {project_path}") + else: + logging.warning(f"Local project folder not found: {project_path}") + + + logging.info(f"Project {project_id} deleted successfully.") + + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=500, + detail=f"An error occurred while deleting the project: {str(e)}" + ) + + async def restore_project(self, project_id: int, user_id: str): + try: + result = ProjectService.update_project( + self.db, + project_id, + is_deleted=False, + user_id=user_id + ) + if result: + message = f"Project with ID {project_id} restored successfully." + else: + message = "Project not found or already restored." + logging.info(message) + return message + except Exception as e: + self.db.rollback() + logging.error(f"An error occurred: {e}") + return "Error occurred during restoration." + + async def restore_all_project(self, repo_name: str, user_id: str): + try: + projects = ProjectService.get_projects_by_repo_name(self.db, repo_name, user_id, is_deleted=True) + for project in projects: + ProjectService.update_project(self.db, project.id, is_deleted=False) + if projects: + message = f"Projects with repo_name {repo_name} restored successfully." + else: + message = "Projects not found or already restored." + logging.info(message) + return message + except Exception as e: + self.db.rollback() + logging.error(f"An error occurred: {e}") + return "Error occurred during restoration." 
+ + async def delete_all_project_by_repo_name(self, repo_name: str, user_id: str): + try: + projects = ProjectService.get_projects_by_repo_name(self.db, repo_name, user_id, is_deleted=False) + for project in projects: + ProjectService.update_project(self.db, project.id, is_deleted=True) + if projects: + message = f"Projects with repo_name {repo_name} deleted successfully." + else: + message = "Projects not found or already deleted." + logging.info(message) + return message + except Exception as e: + self.db.rollback() + logging.error(f"An error occurred: {e}") + return "Error occurred during deletion." + + + def get_project_by_id(db: Session, project_id: int): + return db.query(Project).filter(Project.id == project_id).first() + + def get_projects_by_user_id(db: Session, user_id: str): + return db.query(Project).filter(Project.user_id == user_id).all() + + def create_project(db: Session, project: Project): + project.created_at = datetime.utcnow() + project.updated_at = datetime.utcnow() + db.add(project) + db.commit() + db.refresh(project) + return project + + + def update_project(db: Session, project_id: int, **kwargs): + project = db.query(Project).filter(Project.id == project_id).first() + + if project is None: + return None # Project doesn't exist + + result = db.query(Project).filter(Project.id == project_id).update(kwargs) + + if result > 0: + db.commit() + return result + + return None + + def delete_project(db: Session, project_id: int): + db.query(Project).filter(Project.id == project_id).delete() + db.commit() + + + def get_projects_by_repo_name(db: Session, repo_name: str, user_id: str, is_deleted: bool = False): + try: + projects = db.query(Project).filter( + and_( + Project.repo_name == repo_name, + Project.user_id == user_id, + Project.is_deleted == is_deleted + ) + ).all() + + return projects + except Exception as e: + db.rollback() + # Log the error + logging.error(f"Error fetching projects: {str(e)}") + # You might want to raise a custom exception 
here instead of returning None + return None + + + + diff --git a/app/modules/users/user_router.py b/app/modules/users/user_router.py index 8bf00b40..4877b0f0 100644 --- a/app/modules/users/user_router.py +++ b/app/modules/users/user_router.py @@ -1,5 +1,6 @@ from typing import List -from fastapi import APIRouter, Depends, Query +from fastapi import Depends, Query +from app.modules.utils.APIRouter import APIRouter from sqlalchemy.orm import Session from app.core.database import get_db from app.modules.users.user_controller import UserController @@ -17,4 +18,7 @@ async def get_conversations_for_user( db: Session = Depends(get_db) ): controller = UserController(db) - return await controller.get_conversations_for_user(user_id, start, limit) \ No newline at end of file + return await controller.get_conversations_for_user(user_id, start, limit) + + + diff --git a/app/modules/users/user_schema.py b/app/modules/users/user_schema.py index fa619b6c..57a8ae9a 100644 --- a/app/modules/users/user_schema.py +++ b/app/modules/users/user_schema.py @@ -1,5 +1,6 @@ from typing import List, Optional from pydantic import BaseModel +from datetime import datetime class UserConversationListRequest(BaseModel): user_id: str @@ -14,3 +15,13 @@ class UserConversationListResponse(BaseModel): project_ids: Optional[List[str]] created_at: str updated_at: str + +class CreateUser(BaseModel): + uid: str + email: str + display_name: str + email_verified: bool + created_at: datetime + last_login_at: datetime + provider_info: dict + provider_username: str diff --git a/app/modules/users/user_service.py b/app/modules/users/user_service.py index bf3ecfac..0f5d7b56 100644 --- a/app/modules/users/user_service.py +++ b/app/modules/users/user_service.py @@ -3,6 +3,13 @@ from sqlalchemy.exc import SQLAlchemyError from typing import List from app.modules.conversations.conversation.conversation_model import Conversation +from datetime import datetime + +import logging + +from app.modules.users.user_schema 
# --- UserService methods + user CRUD helpers (reconstructed from diff) ---

def update_last_login(self, uid: str):
    """Stamp last_login_at with the current UTC time for the given user.

    Returns a (message, error) tuple; error is True when the user is
    missing or the commit fails. Never raises.
    """
    logging.info(f"Updating last login time for user with UID: {uid}")
    try:
        user = self.db.query(User).filter(User.uid == uid).first()
        if user:
            user.last_login_at = datetime.utcnow()
            self.db.commit()
            self.db.refresh(user)
            return f"Updated last login time for user with ID: {user.uid}", False
        return "User not found", True
    except Exception as e:
        logging.error(f"Error updating last login time: {e}")
        return "Error updating last login time", True


def create_user(self, user_details: CreateUser):
    """Insert a new User row from a CreateUser payload.

    Returns (uid, message, error); uid is "" and error is True on failure.
    """
    logging.info(
        f"Creating user with email: {user_details.email} | display_name:"
        f" {user_details.display_name}"
    )
    new_user = User(
        uid=user_details.uid,
        email=user_details.email,
        display_name=user_details.display_name,
        email_verified=user_details.email_verified,
        created_at=user_details.created_at,
        last_login_at=user_details.last_login_at,
        provider_info=user_details.provider_info,
        provider_username=user_details.provider_username,
    )
    try:
        self.db.add(new_user)
        self.db.commit()
        self.db.refresh(new_user)
        return new_user.uid, f"User created with ID: {new_user.uid}", False
    except Exception as e:
        logging.error(f"Error creating user: {e}")
        return "", "error creating user", True


def get_user_by_uid(self, uid: str):
    """Return the User with the given uid, or None on miss/failure."""
    try:
        return self.db.query(User).filter(User.uid == uid).first()
    except Exception as e:
        logging.error(f"Error fetching user: {e}")
        return None


# User CRUD operations
def get_user_by_email(db: Session, email: str):
    """Return the first User with the given email, or None."""
    return db.query(User).filter(User.email == email).first()


def get_user_by_username(db: Session, username: str):
    """Return the first User with the given provider_username, or None."""
    return db.query(User).filter(User.provider_username == username).first()


def create_user_record(db: Session, user: User):
    """Persist an already-built User and return it.

    Renamed from `create_user`: the original duplicate definition silently
    shadowed the instance method create_user(self, user_details) above,
    making that method unreachable on the class.
    """
    db.add(user)
    db.commit()
    db.refresh(user)
    return user


def update_user(db: Session, user_id: str, **kwargs):
    """Apply kwargs as column updates to the user with the given uid."""
    db.query(User).filter(User.uid == user_id).update(kwargs)
    db.commit()


def delete_user(db: Session, user_id: str):
    """Hard-delete the user row with the given uid."""
    db.query(User).filter(User.uid == user_id).delete()
    db.commit()
class AIHelper:
    """Builds ChatOpenAI clients, routing through the Portkey gateway
    outside development mode."""

    @staticmethod
    def get_llm_client(user_id, model_name):
        """Resolve the user's provider key and return a configured chat client."""
        provider_key = AIHelper.get_provider_key(user_id)
        return AIHelper.create_client(
            provider_key["provider"], provider_key["key"], model_name, user_id
        )

    @staticmethod
    def get_provider_key(customer_id):
        """Return {"provider", "key"} for the customer.

        Dev mode and customers without an explicit openai preference fall
        back to the environment OPENAI_API_KEY; customers who opted in get
        their own key from the secret manager.
        """
        # Unified on os.getenv (the original mixed os.environ.get and os.getenv).
        if os.getenv("isDevelopmentMode") == "enabled":
            return {"provider": "openai", "key": os.getenv("OPENAI_API_KEY")}
        mongo_helper = MongoDBHelper()
        preference = mongo_helper.get(customer_id, "preferences").get()
        if preference.exists and preference.get("provider") == "openai":
            return {
                "provider": "openai",
                "key": get_secret("openai", customer_id)["api_key"],
            }
        return {"provider": "openai", "key": os.getenv("OPENAI_API_KEY")}

    @staticmethod
    def create_client(provider, key, model_name, user_id):
        """Build a ChatOpenAI client; in non-dev mode requests go through
        Portkey with per-user metadata.

        Raises:
            ValueError: for providers other than "openai". The original
            silently returned None here, which surfaced later as an opaque
            AttributeError at the call site.
        """
        if provider != "openai":
            raise ValueError(f"Unsupported LLM provider: {provider}")
        if os.getenv("isDevelopmentMode") == "enabled":
            return ChatOpenAI(api_key=key, model=model_name)
        portkey_headers = createHeaders(
            api_key=os.environ.get("PORTKEY_API_KEY"),
            provider="openai",
            metadata={"_user": user_id, "environment": os.environ.get("ENV")},
        )
        return ChatOpenAI(
            api_key=key,
            model=model_name,
            base_url=PORTKEY_GATEWAY_URL,
            default_headers=portkey_headers,
        )
app.modules.projects.projects_model import Project from app.modules.users.user_model import User from sqlalchemy.sql import func +import logging class DummyDataSetup: def __init__(self): @@ -26,9 +27,9 @@ def setup_dummy_user(self): ) self.db.add(user) self.db.commit() - print(f"Created dummy user with uid: {user.uid}") + logging.info(f"Created dummy user with uid: {user.uid}") else: - print("Dummy user already exists") + logging.info("Dummy user already exists") finally: self.db.close() @@ -38,13 +39,11 @@ def setup_dummy_project(self): dummy_user = self.db.query(User).filter_by(uid=os.getenv("defaultUsername")).first() if dummy_user: # Check if the dummy project already exists - project_exists = self.db.query(Project).filter_by(directory="dummy_directory").first() + project_exists = self.db.query(Project).filter_by(repo_name="dummy_repo").first() if not project_exists: # Create a dummy project dummy_project = Project( - directory="dummy_directory", - is_default=True, - project_name="Dummy Project Created To Test AI Agent", + id="dummy_project_id", properties=b'{}', repo_name="dummy_repo", branch_name="main", @@ -53,14 +52,14 @@ def setup_dummy_project(self): commit_id="dummy_commit_id", is_deleted=False, updated_at=func.now(), - status="created" + status="ready" ) self.db.add(dummy_project) self.db.commit() - print(f"Created dummy project with id: {dummy_project.id}") + logging.info(f"Created dummy project with id: {dummy_project.id}") else: - print("Dummy project already exists") + logging.info("Dummy project already exists") else: - print("Dummy user not found, cannot create dummy project") + logging.info("Dummy user not found, cannot create dummy project") finally: self.db.close() diff --git a/app/modules/utils/firebase_setup.py b/app/modules/utils/firebase_setup.py new file mode 100644 index 00000000..9d95788f --- /dev/null +++ b/app/modules/utils/firebase_setup.py @@ -0,0 +1,31 @@ +import os +import firebase_admin +from firebase_admin import auth, 
class FirebaseSetup:
    """One-shot initializer for the firebase_admin SDK."""

    @staticmethod
    def firebase_init():
        """Initialize firebase_admin.

        Credentials come from the FIREBASE_SERVICE_ACCOUNT env var
        (base64-encoded service-account JSON); on absence or decode failure
        we fall back to firebase_service_account.json three directories
        above this file (the repo root — TODO confirm layout).
        """
        here = os.path.dirname(os.path.abspath(__file__))
        repo_root = os.path.abspath(os.path.join(here, *([".."] * 3)))
        fallback_path = os.path.join(repo_root, "firebase_service_account.json")
        service_account_base64 = os.getenv("FIREBASE_SERVICE_ACCOUNT")
        cred = None
        if service_account_base64:
            try:
                decoded = base64.b64decode(service_account_base64).decode("utf-8")
                cred = credentials.Certificate(json.loads(decoded))
                logging.info("Loaded Firebase credentials from environment variable.")
            except Exception as e:
                # Was logging.info in the original — a decode failure is an error.
                logging.error(
                    f"Error decoding Firebase service account from environment variable: {e}"
                )
                cred = credentials.Certificate(fallback_path)
                logging.info("Loaded Firebase credentials from local file as fallback.")
        else:
            cred = credentials.Certificate(fallback_path)
            logging.info("Loaded Firebase credentials from local file.")

        firebase_admin.initialize_app(cred)


class ModelHelper:
    """Serialization helper for SQLAlchemy models."""

    @staticmethod
    def model_to_dict(model, max_depth=1, current_depth=0):
        """Convert a mapped model to a dict, following relationships up to
        max_depth levels.

        Non-model inputs are returned unchanged; None (or exceeding the
        depth budget) yields None; unloaded relationships are skipped.
        """
        if model is None or current_depth > max_depth:
            return None

        try:
            mapper = class_mapper(model.__class__)
        except Exception:
            # Was a bare `except:` — never swallow KeyboardInterrupt/SystemExit.
            # Not a mapped SQLAlchemy class: pass the object through untouched.
            return model

        result = {key: getattr(model, key) for key in mapper.column_attrs.keys()}

        for rel_name in mapper.relationships.keys():
            try:
                related = getattr(model, rel_name)
            except DetachedInstanceError:
                # Relationship not loaded on this instance; skip it.
                continue
            if related is None:
                continue
            if isinstance(related, list):
                result[rel_name] = [
                    ModelHelper.model_to_dict(item, max_depth, current_depth + 1)
                    for item in related
                    if item is not None
                ]
            else:
                result[rel_name] = ModelHelper.model_to_dict(
                    related, max_depth, current_depth + 1
                )

        return result