Skip to content

Commit

Permalink
Merge pull request #32 from kris-hansen/cors
Browse files Browse the repository at this point in the history
feat:  cors support
  • Loading branch information
kris-hansen authored Dec 29, 2024
2 parents 0543a45 + 60732cb commit 0ecfcce
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 31 deletions.
35 changes: 33 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,27 @@ comanda server port 8080 # Set server port
comanda server datadir ./data # Set data directory
comanda server auth on # Enable authentication
comanda server auth off # Disable authentication
comanda server newtoken # Generate new bearer token
comanda server newtoken # Generate new bearer token
comanda server cors # Configure CORS settings
```

The server provides several configuration commands:

- `configure`: Interactive configuration for all server settings including port, data directory, authentication, and CORS
- `show`: Display current server configuration including CORS settings
- `port`: Set the server port
- `datadir`: Set the data directory for YAML files
- `auth`: Enable/disable authentication
- `newtoken`: Generate a new bearer token
- `cors`: Configure CORS settings interactively

The CORS configuration allows you to:
- Enable/disable CORS headers
- Set allowed origins (use * for all, or specify domains)
- Configure allowed HTTP methods
- Set allowed headers
- Define max age for preflight requests

The server configuration is stored in your `.env` file alongside provider and model settings:

```yaml
Expand All @@ -257,7 +275,20 @@ server:
data_dir: "examples" # Directory containing YAML files to process
bearer_token: "your-generated-token"
enabled: true # Whether authentication is required
```
cors:
enabled: true # Enable/disable CORS
allowed_origins: ["*"] # List of allowed origins, ["*"] for all
allowed_methods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"] # List of allowed HTTP methods
allowed_headers: ["Authorization", "Content-Type"] # List of allowed headers
max_age: 3600 # Max age for preflight requests in seconds
```

The CORS configuration allows you to control Cross-Origin Resource Sharing settings:
- `enabled`: Enable or disable CORS headers (default: true)
- `allowed_origins`: List of origins allowed to access the API. Use `["*"]` to allow all origins, or specify domains like `["https://example.com"]`
- `allowed_methods`: List of HTTP methods allowed for cross-origin requests
- `allowed_headers`: List of headers allowed in requests
- `max_age`: How long browsers should cache preflight request results

To start the server:

Expand Down
111 changes: 111 additions & 0 deletions cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ var showServerCmd = &cobra.Command{
if server.BearerToken != "" {
fmt.Printf("Bearer Token: %s\n", server.BearerToken)
}

// Display CORS configuration
fmt.Println("\nCORS Configuration:")
fmt.Printf("Enabled: %v\n", server.CORS.Enabled)
if server.CORS.Enabled {
fmt.Printf("Allowed Origins: %s\n", strings.Join(server.CORS.AllowedOrigins, ", "))
fmt.Printf("Allowed Methods: %s\n", strings.Join(server.CORS.AllowedMethods, ", "))
fmt.Printf("Allowed Headers: %s\n", strings.Join(server.CORS.AllowedHeaders, ", "))
fmt.Printf("Max Age: %d seconds\n", server.CORS.MaxAge)
}
fmt.Println()
},
}
Expand Down Expand Up @@ -303,6 +313,106 @@ func configureServer(reader *bufio.Reader, envConfig *config.EnvConfig) error {
enableStr, _ := reader.ReadString('\n')
serverConfig.Enabled = strings.TrimSpace(strings.ToLower(enableStr)) == "y"

// Configure CORS settings
if err := configureCORS(reader, envConfig); err != nil {
return fmt.Errorf("error configuring CORS: %v", err)
}

envConfig.UpdateServerConfig(*serverConfig)
return nil
}

var corsCmd = &cobra.Command{
Use: "cors",
Short: "Configure CORS settings",
Long: `Configure Cross-Origin Resource Sharing (CORS) settings for the server`,
Run: func(cmd *cobra.Command, args []string) {
configPath := config.GetEnvPath()
envConfig, err := config.LoadEnvConfigWithPassword(configPath)
if err != nil {
fmt.Printf("Error loading configuration: %v\n", err)
return
}

reader := bufio.NewReader(os.Stdin)
if err := configureCORS(reader, envConfig); err != nil {
fmt.Printf("Error configuring CORS: %v\n", err)
return
}

if err := config.SaveEnvConfig(configPath, envConfig); err != nil {
fmt.Printf("Error saving configuration: %v\n", err)
return
}

fmt.Println("CORS configuration saved successfully!")
},
}

// configureCORS handles the interactive CORS configuration
func configureCORS(reader *bufio.Reader, envConfig *config.EnvConfig) error {
serverConfig := envConfig.GetServerConfig()

// Prompt for CORS enable/disable
fmt.Print("Enable CORS? (y/n): ")
enableStr, _ := reader.ReadString('\n')
serverConfig.CORS.Enabled = strings.TrimSpace(strings.ToLower(enableStr)) == "y"

if serverConfig.CORS.Enabled {
// Prompt for allowed origins
fmt.Print("Enter allowed origins (comma-separated, * for all, default: *): ")
originsStr, _ := reader.ReadString('\n')
originsStr = strings.TrimSpace(originsStr)
if originsStr != "" && originsStr != "*" {
serverConfig.CORS.AllowedOrigins = strings.Split(originsStr, ",")
for i := range serverConfig.CORS.AllowedOrigins {
serverConfig.CORS.AllowedOrigins[i] = strings.TrimSpace(serverConfig.CORS.AllowedOrigins[i])
}
} else {
serverConfig.CORS.AllowedOrigins = []string{"*"}
}

// Prompt for allowed methods
fmt.Print("Enter allowed methods (comma-separated, default: GET,POST,PUT,DELETE,OPTIONS): ")
methodsStr, _ := reader.ReadString('\n')
methodsStr = strings.TrimSpace(methodsStr)
if methodsStr != "" {
serverConfig.CORS.AllowedMethods = strings.Split(methodsStr, ",")
for i := range serverConfig.CORS.AllowedMethods {
serverConfig.CORS.AllowedMethods[i] = strings.TrimSpace(serverConfig.CORS.AllowedMethods[i])
}
} else {
serverConfig.CORS.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
}

// Prompt for allowed headers
fmt.Print("Enter allowed headers (comma-separated, default: Authorization,Content-Type): ")
headersStr, _ := reader.ReadString('\n')
headersStr = strings.TrimSpace(headersStr)
if headersStr != "" {
serverConfig.CORS.AllowedHeaders = strings.Split(headersStr, ",")
for i := range serverConfig.CORS.AllowedHeaders {
serverConfig.CORS.AllowedHeaders[i] = strings.TrimSpace(serverConfig.CORS.AllowedHeaders[i])
}
} else {
serverConfig.CORS.AllowedHeaders = []string{"Authorization", "Content-Type"}
}

// Prompt for max age
fmt.Print("Enter max age in seconds (default: 3600): ")
maxAgeStr, _ := reader.ReadString('\n')
maxAgeStr = strings.TrimSpace(maxAgeStr)
if maxAgeStr != "" {
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
return fmt.Errorf("invalid max age: %v", err)
}
serverConfig.CORS.MaxAge = maxAge
} else {
serverConfig.CORS.MaxAge = 3600
}
}

envConfig.UpdateServerConfig(*serverConfig)
return nil
}
Expand All @@ -314,5 +424,6 @@ func init() {
serverCmd.AddCommand(updateDataDirCmd)
serverCmd.AddCommand(toggleAuthCmd)
serverCmd.AddCommand(newTokenCmd)
serverCmd.AddCommand(corsCmd)
rootCmd.AddCommand(serverCmd)
}
26 changes: 22 additions & 4 deletions utils/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,22 @@ type Provider struct {
Models []Model `yaml:"models"`
}

// CORSConfig represents CORS configuration options
type CORSConfig struct {
Enabled bool `yaml:"enabled"`
AllowedOrigins []string `yaml:"allowed_origins,omitempty"`
AllowedMethods []string `yaml:"allowed_methods,omitempty"`
AllowedHeaders []string `yaml:"allowed_headers,omitempty"`
MaxAge int `yaml:"max_age,omitempty"`
}

// ServerConfig represents the server configuration
type ServerConfig struct {
Port int `yaml:"port"`
BearerToken string `yaml:"bearer_token,omitempty"`
Enabled bool `yaml:"enabled"`
DataDir string `yaml:"data_dir"`
Port int `yaml:"port"`
BearerToken string `yaml:"bearer_token,omitempty"`
Enabled bool `yaml:"enabled"`
DataDir string `yaml:"data_dir"`
CORS CORSConfig `yaml:"cors"`
}

// EnvConfig represents the complete environment configuration
Expand Down Expand Up @@ -349,6 +359,13 @@ func (c *EnvConfig) GetServerConfig() *ServerConfig {
Port: 8080,
Enabled: false,
DataDir: "data",
CORS: CORSConfig{
Enabled: true,
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Authorization", "Content-Type"},
MaxAge: 3600,
},
}
}
return c.Server
Expand All @@ -363,6 +380,7 @@ func (c *EnvConfig) UpdateServerConfig(config ServerConfig) {
c.Server.BearerToken = config.BearerToken
c.Server.Enabled = config.Enabled
c.Server.DataDir = config.DataDir
c.Server.CORS = config.CORS
}

// GetProviderConfig retrieves configuration for a specific provider
Expand Down
106 changes: 85 additions & 21 deletions utils/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,66 @@ func (s *Server) validatePath(path string) (string, error) {
return fullPath, nil
}

// handleCORS adds CORS headers based on configuration
func (s *Server) handleCORS(w http.ResponseWriter) {
if !s.config.CORS.Enabled {
return
}

// Set allowed origins
if len(s.config.CORS.AllowedOrigins) > 0 {
origin := strings.Join(s.config.CORS.AllowedOrigins, ", ")
w.Header().Set("Access-Control-Allow-Origin", origin)
} else {
w.Header().Set("Access-Control-Allow-Origin", "*")
}

// Set allowed methods
if len(s.config.CORS.AllowedMethods) > 0 {
methods := strings.Join(s.config.CORS.AllowedMethods, ", ")
w.Header().Set("Access-Control-Allow-Methods", methods)
} else {
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
}

// Set allowed headers
if len(s.config.CORS.AllowedHeaders) > 0 {
headers := strings.Join(s.config.CORS.AllowedHeaders, ", ")
w.Header().Set("Access-Control-Allow-Headers", headers)
} else {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
}

// Set max age
if s.config.CORS.MaxAge > 0 {
w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", s.config.CORS.MaxAge))
} else {
w.Header().Set("Access-Control-Max-Age", "3600")
}
}

// combinedMiddleware applies middleware in the correct order based on request method
func (s *Server) combinedMiddleware(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Always set CORS headers first
s.handleCORS(w)

// Handle OPTIONS requests immediately
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}

// For non-OPTIONS requests, proceed with logging and auth
logRequest(func(w http.ResponseWriter, r *http.Request) {
if !checkAuth(s.config, w, r) {
return
}
handler(w, r)
})(w, r)
}
}

// New creates a new HTTP server with the given configuration
func New(envConfig *config.EnvConfig) (*http.Server, error) {
// Get server configuration
Expand All @@ -100,12 +160,19 @@ func New(envConfig *config.EnvConfig) (*http.Server, error) {
return nil, fmt.Errorf("error creating data directory: %v", err)
}

// Convert config.ServerConfig to our internal ServerConfig
// Convert config.ServerConfig to our internal ServerConfig with default CORS settings
srvConfig := &ServerConfig{
Port: serverConfig.Port,
DataDir: serverConfig.DataDir,
BearerToken: serverConfig.BearerToken,
Enabled: serverConfig.Enabled,
CORS: CORSConfig{
Enabled: true,
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Authorization", "Content-Type"},
MaxAge: 3600,
},
}

s := &Server{
Expand All @@ -130,24 +197,24 @@ func New(envConfig *config.EnvConfig) (*http.Server, error) {

// routes sets up the server routes
func (s *Server) routes() {
// Health check endpoint
s.mux.HandleFunc("/health", logRequest(func(w http.ResponseWriter, r *http.Request) {
// Health check endpoint - no auth required
s.mux.HandleFunc("/health", s.combinedMiddleware(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(HealthResponse{
Status: "ok",
Timestamp: time.Now().Format(time.RFC3339),
})
}))

// File operations
s.mux.HandleFunc("/list", logRequest(s.handleListFiles))
s.mux.HandleFunc("/files", logRequest(s.handleFileOperation))
s.mux.HandleFunc("/files/bulk", logRequest(s.handleBulkFileOperation))
s.mux.HandleFunc("/files/backup", logRequest(s.handleFileBackup))
s.mux.HandleFunc("/files/restore", logRequest(s.handleFileRestore))
// File operations - require auth
s.mux.HandleFunc("/list", s.combinedMiddleware(s.handleListFiles))
s.mux.HandleFunc("/files", s.combinedMiddleware(s.handleFileOperation))
s.mux.HandleFunc("/files/bulk", s.combinedMiddleware(s.handleBulkFileOperation))
s.mux.HandleFunc("/files/backup", s.combinedMiddleware(s.handleFileBackup))
s.mux.HandleFunc("/files/restore", s.combinedMiddleware(s.handleFileRestore))

// Provider operations
s.mux.HandleFunc("/providers", logRequest(func(w http.ResponseWriter, r *http.Request) {
// Provider operations - require auth
s.mux.HandleFunc("/providers", s.combinedMiddleware(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
s.handleGetProviders(w, r)
Expand All @@ -163,37 +230,34 @@ func (s *Server) routes() {
}
}))

// Provider validation
s.mux.HandleFunc("/providers/validate", logRequest(func(w http.ResponseWriter, r *http.Request) {
// Provider validation - requires auth
s.mux.HandleFunc("/providers/validate", s.combinedMiddleware(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
s.handleValidateProvider(w, r)
}))

// Environment operations
s.mux.HandleFunc("/env/encrypt", logRequest(func(w http.ResponseWriter, r *http.Request) {
// Environment operations - require auth
s.mux.HandleFunc("/env/encrypt", s.combinedMiddleware(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
s.handleEncryptEnv(w, r)
}))

s.mux.HandleFunc("/env/decrypt", logRequest(func(w http.ResponseWriter, r *http.Request) {
s.mux.HandleFunc("/env/decrypt", s.combinedMiddleware(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
s.handleDecryptEnv(w, r)
}))

// Process endpoint
s.mux.HandleFunc("/process", logRequest(func(w http.ResponseWriter, r *http.Request) {
if !checkAuth(s.config, w, r) {
return
}
// Process endpoint - requires auth
s.mux.HandleFunc("/process", s.combinedMiddleware(func(w http.ResponseWriter, r *http.Request) {
handleProcess(w, r, s.config, s.envConfig)
}))
}
Expand Down
Loading

0 comments on commit 0ecfcce

Please sign in to comment.