From 4ca6a0700deec3fde99c43b24f2b248d24a0a213 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Thu, 27 Apr 2023 12:55:26 -0600 Subject: [PATCH] PR feedback addressed - Decoupled bash reliance for template creation - Tests updated and working - Add ExpandHomeDir() and test - Updated TestPathExistsInInventory to be more realistic - Moved bash templates to their own directory - Fixed pkg tests to specifications in PR - Updated TemplateExists to support multiple inventoryPaths --- cmd/new.go | 87 ++++++----- cmd/new_test.go | 196 +++++++++++++++++++----- cmd/root.go | 48 ++---- cmd/run.go | 2 +- pkg/files/file.go | 123 +++++++++++---- pkg/files/file_test.go | 199 +++++++++++++++++++++---- pkg/files/yaml.go | 4 +- pkg/files/yaml_test.go | 14 +- templates/{ => bash}/README.md.tmpl | 0 templates/{ => bash}/bashTTP.sh.tmpl | 0 templates/{ => bash}/bashTTP.yaml.tmpl | 0 11 files changed, 503 insertions(+), 170 deletions(-) rename templates/{ => bash}/README.md.tmpl (100%) rename templates/{ => bash}/bashTTP.sh.tmpl (100%) rename templates/{ => bash}/bashTTP.yaml.tmpl (100%) diff --git a/cmd/new.go b/cmd/new.go index ec5f8c7e..2981c2f6 100644 --- a/cmd/new.go +++ b/cmd/new.go @@ -20,6 +20,7 @@ THE SOFTWARE. package cmd import ( + "fmt" "os" "path/filepath" "text/template" @@ -28,6 +29,7 @@ import ( "github.com/facebookincubator/ttpforge/pkg/files" "github.com/facebookincubator/ttpforge/pkg/logging" "github.com/spf13/cobra" + "github.com/spf13/viper" "go.uber.org/zap" ) @@ -37,7 +39,7 @@ func init() { } var newTTPInput NewTTPInput -var dirPath string +var ttpDir string // NewTTPInput contains the inputs required to create a new TTP from a template. type NewTTPInput struct { @@ -72,24 +74,45 @@ func NewTTPBuilderCmd() *cobra.Command { `, PreRunE: func(cmd *cobra.Command, args []string) error { - if err := validateTemplateFlag(cmd); err != nil { - return err - } - requiredFlags := []string{ "template", "path", "ttp-type", } - return checkRequiredFlags(cmd, requiredFlags) + for _, flag := range requiredFlags { + if err := cmd.MarkFlagRequired(flag); err != nil { + return err + } + } + + return nil }, Run: func(cmd *cobra.Command, args []string) { + inventoryPaths := viper.GetStringSlice("inventory") + + var templatePath string + + relTmplPath := filepath.Join("templates", newTTPInput.Template) + + // Iterate through inventory paths and find the first matching template + for _, invPath := range inventoryPaths { + invPath = files.ExpandHomeDir(invPath) + absTemplatePath := filepath.Join(invPath, "..", relTmplPath) + templateFound, err := files.PathExistsInInventory(relTmplPath, []string{invPath}) + cobra.CheckErr(err) + + if templateFound { + templatePath = absTemplatePath + break + } + } + // Create the filepath for the input TTP if it doesn't already exist. - bashTTPFile := newTTPInput.Path + ttpFile := newTTPInput.Path - dirPath = filepath.Dir(bashTTPFile) - if err := files.CreateDirIfNotExists(dirPath); err != nil { + ttpDir = filepath.Dir(ttpFile) + if err := files.CreateDirIfNotExists(ttpDir); err != nil { cobra.CheckErr(err) } @@ -98,9 +121,10 @@ func NewTTPBuilderCmd() *cobra.Command { logging.Logger.Sugar().Errorw("failed to create TTP with:", newTTPInput, zap.Error(err)) cobra.CheckErr(err) } + // Populate templated TTP file tmpl := template.Must( - template.ParseFiles(filepath.Join("templates", "bashTTP.yaml.tmpl"))) + template.ParseFiles(ttpFile)) yamlF, err := os.Create(newTTPInput.Path) cobra.CheckErr(err) @@ -110,10 +134,10 @@ func NewTTPBuilderCmd() *cobra.Command { cobra.CheckErr(err) } - // Create README from template - readme := filepath.Join(dirPath, "README.md") + // Create README from templae + readme := filepath.Join(ttpDir, "README.md") tmpl = template.Must( - template.ParseFiles(filepath.Join("templates", "README.md.tmpl"))) + template.ParseFiles(filepath.Join(templatePath, "README.md.tmpl"))) readmeF, err := os.Create(readme) cobra.CheckErr(err) @@ -123,17 +147,19 @@ func NewTTPBuilderCmd() *cobra.Command { cobra.CheckErr(err) } - // Create templated bash script (if applicable) + // Create templated file-based TTP (if applicable) if newTTPInput.TTPType == "file" { - tmpl = template.Must( - template.ParseFiles(filepath.Join("templates", "bashTTP.sh.tmpl"))) + if newTTPInput.Template == "bash" { + tmpl = template.Must( + template.ParseFiles(filepath.Join(templatePath, "bashTTP.sh.tmpl"))) - bashScriptF, err := os.Create(filepath.Join(dirPath, "bashTTP.sh")) - cobra.CheckErr(err) - defer bashScriptF.Close() - - if err := tmpl.Execute(bashScriptF, ttp); err != nil { + bashScriptF, err := os.Create(filepath.Join(ttpDir, "bashTTP.sh")) cobra.CheckErr(err) + defer bashScriptF.Close() + + if err := tmpl.Execute(bashScriptF, ttp); err != nil { + cobra.CheckErr(err) + } } } }, @@ -162,7 +188,7 @@ func createTTP() (*blocks.TTP, error) { if newTTPInput.TTPType == "file" { step = blocks.NewFileStep() step.(*blocks.FileStep).Act.Name = "example_file_step" - step.(*blocks.FileStep).FilePath = filepath.Join(dirPath, "bashTTP.sh") + step.(*blocks.FileStep).FilePath = filepath.Join(ttpDir, fmt.Sprintf("%sTTP.sh", newTTPInput.Template)) step.(*blocks.FileStep).Args = newTTPInput.Args } else { step = blocks.NewBasicStep() @@ -196,20 +222,3 @@ func createTTP() (*blocks.TTP, error) { return ttp, nil } - -func validateTemplateFlag(cmd *cobra.Command) error { - templateName, err := cmd.Flags().GetString("template") - if err != nil { - return err - } - exists, err := files.TemplateExists(templateName) - if err != nil { - logging.Logger.Sugar().Errorw("unsupported template:", templateName, zap.Error(err)) - return err - } - if !exists { - logging.Logger.Sugar().Errorw("template not found:", templateName, zap.Error(err)) - return err - } - return nil -} diff --git a/cmd/new_test.go b/cmd/new_test.go index cb0680b3..b1025287 100644 --- a/cmd/new_test.go +++ b/cmd/new_test.go @@ -22,10 +22,9 @@ package cmd_test import ( "bytes" "fmt" + "io" "os" - "path" "path/filepath" - "runtime" "testing" "github.com/facebookincubator/ttpforge/cmd" @@ -35,25 +34,127 @@ import ( "github.com/stretchr/testify/require" ) -var absConfigPath string +func createTestInventory(t *testing.T, dir string) { + t.Helper() -func init() { - // Find the absolute path to the config.yaml file - _, currentFilePath, _, _ := runtime.Caller(0) - absConfigPath = filepath.Join(path.Dir(currentFilePath), "..", "config.yaml") + lateralMovementDir := filepath.Join(dir, "lateral-movement", "ssh") + if err := os.MkdirAll(lateralMovementDir, 0755); err != nil { + t.Fatalf("failed to create lateral movement dir: %v", err) + } + + privEscalationDir := filepath.Join(dir, "privilege-escalation", "credential-theft", "hello-world") + if err := os.MkdirAll(privEscalationDir, 0755); err != nil { + t.Fatalf("failed to create privilege escalation dir: %v", err) + } + + testFiles := []struct { + path string + contents string + }{ + { + path: filepath.Join(lateralMovementDir, "rogue-ssh-key.yaml"), + contents: fmt.Sprintln("---\nname: test-rogue-ssh-key-contents"), + }, + { + path: filepath.Join(privEscalationDir, "priv-esc.yaml"), + contents: fmt.Sprintln("---\nname: test-priv-esc-contents"), + }, + } + + for _, file := range testFiles { + f, err := os.Create(file.path) + if err != nil { + t.Fatalf("failed to create test file: %v", err) + } + if _, err := io.WriteString(f, file.contents); err != nil { + t.Fatalf("failed to write to test file: %v", err) + } + f.Close() + } +} + +func createBashTestTemplates(t *testing.T, dir string) { + t.Helper() + + templateDir := filepath.Join(dir, "templates") + if err := os.MkdirAll(templateDir, 0755); err != nil { + t.Fatalf("failed to create templates directory: %v", err) + } + + // Create the "bash" directory inside the "templates" directory + bashDir := filepath.Join(templateDir, "bash") + if err := os.MkdirAll(bashDir, 0755); err != nil { + t.Fatalf("failed to create bash directory: %v", err) + } + + // Create basic TTP template + basicTemplateFile, err := os.Create(filepath.Join(bashDir, "bashTTP.yaml.tmpl")) + if err != nil { + t.Fatalf("failed to create test template: %v", err) + } + if _, err := io.WriteString(basicTemplateFile, "test basic template content"); err != nil { + t.Fatalf("failed to write to test template: %v", err) + } + defer basicTemplateFile.Close() - // Change into the same directory as the config (repo root). - if err := os.Chdir(filepath.Dir(absConfigPath)); err != nil { - panic(err) + // Create README template + readmeTmpl := "# This is a test" + + readmeTemplateFile, err := os.Create(filepath.Join(bashDir, "README.md.tmpl")) + if err != nil { + t.Fatalf("failed to create test template: %v", err) + } + if _, err := io.WriteString(readmeTemplateFile, readmeTmpl); err != nil { + t.Fatalf("failed to write to test template: %v", err) } + defer readmeTemplateFile.Close() + + // Create file TTP template + bashScriptTemplateFile, err := os.Create(filepath.Join(bashDir, "bashTTP.sh.tmpl")) + if err != nil { + t.Fatalf("failed to create test template: %v", err) + } + if _, err := io.WriteString(bashScriptTemplateFile, "test file template content"); err != nil { + t.Fatalf("failed to write to test template: %v", err) + } + defer bashScriptTemplateFile.Close() } func TestCreateAndRunTTP(t *testing.T) { - newTTPBuilderCmd := cmd.NewTTPBuilderCmd() + // Create a temporary file + testDir, err := os.MkdirTemp("", "cmd-new-test") + assert.NoError(t, err, "failed to create temporary directory") + // Clean up the temporary directory + defer os.RemoveAll(testDir) + + createTestInventory(t, testDir) + createBashTestTemplates(t, testDir) + + // Create ttp dir + ttpDir := filepath.Join(testDir, "ttps") + if err := os.MkdirAll(ttpDir, 0755); err != nil { + t.Fatalf("failed to create ttps directory: %v", err) + } - basicTestPath := filepath.Join("ttps", "test", "testBasicTTP.yaml") - fileTestPath := filepath.Join("ttps", "test", "testFileTTP.yaml") + // config for the test + testConfigYAML := `--- +inventory: + - ` + ttpDir + ` +logfile: "" +nocolor: false +stacktrace: false +verbose: false +` + // Write the config to a temporary file + testConfigYAMLPath := filepath.Join(testDir, "config.yaml") + err = os.WriteFile(testConfigYAMLPath, []byte(testConfigYAML), 0644) + assert.NoError(t, err, "failed to write the temporary YAML file") + + basicTestPath := filepath.Join(ttpDir, "basicTest", "testBasicTTP.yaml") + fileTestPath := filepath.Join(ttpDir, "fileTest", "testFileTTP.yaml") + + newTTPBuilderCmd := cmd.NewTTPBuilderCmd() testCases := []struct { name string setFlags func() @@ -62,23 +163,10 @@ func TestCreateAndRunTTP(t *testing.T) { expectError bool expectedErrorMsg string }{ - { - name: "All required flags set", - setFlags: func() { - _ = newTTPBuilderCmd.Flags().Set("config", absConfigPath) - _ = newTTPBuilderCmd.Flags().Set("path", basicTestPath) - _ = newTTPBuilderCmd.Flags().Set("template", "bash") - _ = newTTPBuilderCmd.Flags().Set("ttp-type", "file") - _ = newTTPBuilderCmd.Flags().Set("args", "arg1,arg2,arg3") - _ = newTTPBuilderCmd.Flags().Set("cleanup", "true") - _ = newTTPBuilderCmd.Flags().Set("env", "EXAMPLE_ENV_VAR=example_value") - }, - expectError: false, - }, { name: "Create basic bash TTP", setFlags: func() { - _ = newTTPBuilderCmd.Flags().Set("config", absConfigPath) + _ = newTTPBuilderCmd.Flags().Set("config", testConfigYAMLPath) _ = newTTPBuilderCmd.Flags().Set("path", basicTestPath) _ = newTTPBuilderCmd.Flags().Set("template", "bash") _ = newTTPBuilderCmd.Flags().Set("ttp-type", "basic") @@ -91,11 +179,12 @@ func TestCreateAndRunTTP(t *testing.T) { { name: "Create file-based bash TTP", setFlags: func() { - _ = newTTPBuilderCmd.Flags().Set("config", absConfigPath) - _ = newTTPBuilderCmd.Flags().Set("path", basicTestPath) + _ = newTTPBuilderCmd.Flags().Set("config", testConfigYAMLPath) + _ = newTTPBuilderCmd.Flags().Set("path", fileTestPath) _ = newTTPBuilderCmd.Flags().Set("template", "bash") _ = newTTPBuilderCmd.Flags().Set("ttp-type", "file") _ = newTTPBuilderCmd.Flags().Set("cleanup", "true") + _ = newTTPBuilderCmd.Flags().Set("env", "EXAMPLE_ENV_VAR=example_value") }, expected: fileTestPath, }, @@ -103,6 +192,40 @@ func TestCreateAndRunTTP(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + // Set flags for the test case + tc.setFlags() + + // Update tc.input.TTPType value + ttpTypeFlag, err := newTTPBuilderCmd.Flags().GetString("ttp-type") + if err != nil { + t.Fatalf("failed to get ttp-type flag: %v", err) + } + tc.input.TTPType = ttpTypeFlag + + if err := os.Chdir(filepath.Dir(testConfigYAMLPath)); err != nil { + t.Fatalf("failed to change into test directory: %v", err) + } + // Set filepath for current test TTP + ttpPath := basicTestPath + if tc.input.TTPType == "file" { + ttpPath = fileTestPath + } + + // Create the test TTP directory if it doesn't already exist + if err := os.MkdirAll(filepath.Dir(ttpPath), 0755); err != nil { + t.Fatalf("failed to create ttps directory: %v", err) + } + + // Create the test TTP file + ttpFile, err := os.Create(ttpPath) + if err != nil { + t.Fatalf("failed to create test ttp: %v", err) + } + if _, err := io.WriteString(ttpFile, fmt.Sprintln("---\nname: test-ttp-contents")); err != nil { + t.Fatalf("failed to write to test ttp: %v", err) + } + defer ttpFile.Close() + // Reset flags newTTPBuilderCmd.Flags().VisitAll(func(flag *pflag.Flag) { _ = newTTPBuilderCmd.Flags().Set(flag.Name, "") @@ -111,16 +234,21 @@ func TestCreateAndRunTTP(t *testing.T) { // Set flags for the test case tc.setFlags() - // Call ExecuteContext with the custom context - err := newTTPBuilderCmd.Execute() + err = newTTPBuilderCmd.Execute() if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.expectedErrorMsg) } else { require.NoError(t, err) - // Ensure we are able to read ttps from the TTP directory - _, err = os.Stat(basicTestPath) - assert.NoError(t, err, "the test directory should exist") + if tc.input.TTPType == "basic" { + _, err = os.Stat(basicTestPath) + assert.NoError(t, err, "the test directory should exist") + } else if tc.input.TTPType == "file" { + _, err = os.Stat(fileTestPath) + assert.NoError(t, err, "the test directory should exist") + } else { + t.Fatal("Invalid TTPType provided") + } } // Check if the bash script file was created (for file TTP type) diff --git a/cmd/root.go b/cmd/root.go index b9be6a31..97bfbc8d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -20,7 +20,6 @@ THE SOFTWARE. package cmd import ( - "fmt" "os" "strconv" @@ -47,7 +46,8 @@ type Config struct { var ( // Logger is used to facilitate logging throughout TTPForge. Logger *zap.Logger - conf = &Config{} + // Conf refers to the configuration used throughout TTPForge. + Conf = &Config{} rootCmd = &cobra.Command{ Use: "ttpforge", @@ -57,9 +57,9 @@ TTPForge is a Purple Team engagement tool to execute Tactics, Techniques, and Pr `, TraverseChildren: true, PersistentPostRun: func(cmd *cobra.Command, args []string) { - if conf.saveConfig != "" { + if Conf.saveConfig != "" { // https://github.com/facebookincubator/ttpforge/issues/4 - if err := WriteConfigToFile(conf.saveConfig); err != nil { + if err := WriteConfigToFile(Conf.saveConfig); err != nil { logging.Logger.Error("failed to write config values", zap.Error(err)) } } @@ -81,7 +81,7 @@ TTPForge is a Purple Team engagement tool to execute Tactics, Techniques, and Pr // - error: An error object if any issues occur during the marshaling or file // writing process, otherwise nil. func WriteConfigToFile(filepath string) error { - yamlBytes, err := yaml.Marshal(conf) + yamlBytes, err := yaml.Marshal(Conf) if err != nil { return err } @@ -99,12 +99,12 @@ func init() { Logger = logging.Logger cobra.OnInitialize(initConfig) - // These flags are set using Cobra only, so we populate the conf.* variables directly + // These flags are set using Cobra only, so we populate the Conf.* variables directly // reference the unset values in the struct Config above. - rootCmd.PersistentFlags().StringVarP(&conf.cfgFile, "config", "c", "config.yaml", "Config file (default is config.yaml)") - rootCmd.PersistentFlags().StringVar(&conf.saveConfig, "save-config", "", "Writes values used in execution to the specified location") - rootCmd.PersistentFlags().BoolVar(&conf.StackTrace, "stacktrace", false, "Show stacktrace when logging error") - rootCmd.PersistentFlags().StringArrayVar(&conf.InventoryPath, "inventory", []string{"."}, "list of paths to search for ttps") + rootCmd.PersistentFlags().StringVarP(&Conf.cfgFile, "config", "c", "config.yaml", "Config file (default is config.yaml)") + rootCmd.PersistentFlags().StringVar(&Conf.saveConfig, "save-config", "", "Writes values used in execution to the specified location") + rootCmd.PersistentFlags().BoolVar(&Conf.StackTrace, "stacktrace", false, "Show stacktrace when logging error") + rootCmd.PersistentFlags().StringArrayVar(&Conf.InventoryPath, "inventory", []string{"."}, "list of paths to search for ttps") // Notice here that the values from the command line are not populated in this instance. // This is because we are using viper in addition to cobra to manage these values - // Cobra will look for these values on the command line. If the values are not present, @@ -140,9 +140,9 @@ func init() { // initConfig reads in config file and ENV variables if set. func initConfig() { - if conf.cfgFile != "" { + if Conf.cfgFile != "" { // Use config file from the flag. - viper.SetConfigFile(conf.cfgFile) + viper.SetConfigFile(Conf.cfgFile) } else { // Search config in current directory with name ".cobra" (without extension). viper.AddConfigPath(".") @@ -162,33 +162,13 @@ func initConfig() { cobra.CheckErr(err) } - if err := viper.Unmarshal(conf, func(config *mapstructure.DecoderConfig) { + if err := viper.Unmarshal(Conf, func(config *mapstructure.DecoderConfig) { config.IgnoreUntaggedFields = true }); err != nil { cobra.CheckErr(err) } - err := logging.InitLog(conf.NoColor, conf.Logfile, conf.Verbose, conf.StackTrace) + err := logging.InitLog(Conf.NoColor, Conf.Logfile, Conf.Verbose, Conf.StackTrace) cobra.CheckErr(err) Logger = logging.Logger } - -func getStringFlagOrDefault(cmd *cobra.Command, flag string) *string { - value, _ := cmd.Flags().GetString(flag) - if value == "" { - viperValue := viper.GetString(flag) - return &viperValue - } - return &value -} - -func checkRequiredFlags(cmd *cobra.Command, requiredFlags []string) error { - for _, flag := range requiredFlags { - value := getStringFlagOrDefault(cmd, flag) - if *value == "" { - return fmt.Errorf("required flag '%s' not set", flag) - } - } - - return nil -} diff --git a/cmd/run.go b/cmd/run.go index bd0528a4..d3d264fe 100755 --- a/cmd/run.go +++ b/cmd/run.go @@ -37,7 +37,7 @@ func RunTTPCmd() *cobra.Command { Short: "Run the forgery using the file specified in args.", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - if _, err := files.ExecuteYAML(args[0]); err != nil { + if _, err := files.ExecuteYAML(args[0], Conf.InventoryPath); err != nil { Logger.Sugar().Errorw("failed to execute TTP", zap.Error(err)) } }, diff --git a/pkg/files/file.go b/pkg/files/file.go index 6a22deaa..40f207dc 100755 --- a/pkg/files/file.go +++ b/pkg/files/file.go @@ -20,11 +20,11 @@ THE SOFTWARE. package files import ( + "errors" "fmt" + "io/fs" "os" "path/filepath" - - "github.com/spf13/viper" ) // CreateDirIfNotExists checks if a directory exists at the given path and creates it if it does not exist. @@ -37,17 +37,21 @@ import ( // Returns: // // error: An error if the directory could not be created. +// +// Example: +// +// dirPath := "path/to/directory" +// err := CreateDirIfNotExists(dirPath) +// +// if err != nil { +// log.Fatalf("failed to create directory: %v", err) +// } func CreateDirIfNotExists(path string) error { fileInfo, err := os.Stat(path) - if err != nil { - if os.IsNotExist(err) { - // Create the directory if it does not exist - err = os.MkdirAll(path, 0755) - if err != nil { - return err - } - } else { + if errors.Is(err, fs.ErrNotExist) { + // Create the directory if it does not exist + if err := os.MkdirAll(path, 0755); err != nil { return err } } else { @@ -60,12 +64,54 @@ func CreateDirIfNotExists(path string) error { return nil } +// ExpandHomeDir expands the tilde character in a path to the user's home directory. +// The function takes a string representing a path and checks if the first character is a tilde (~). +// If it is, the function replaces the tilde with the user's home directory. The path is returned +// unchanged if it does not start with a tilde or if there's an error retrieving the user's home +// directory. +// +// Borrowed from https://github.com/l50/goutils/blob/e91b7c4e18e23c53e35d04fa7961a5a14ca8ef39/fileutils.go#L283-L318 +// +// Parameters: +// +// path: The string containing a path that may start with a tilde (~) character. +// +// Returns: +// +// string: The expanded path with the tilde replaced by the user's home directory, or the +// +// original path if it does not start with a tilde or there's an error retrieving +// the user's home directory. +// +// Example: +// +// pathWithTilde := "~/Documents/myfile.txt" +// expandedPath := ExpandHomeDir(pathWithTilde) +// log.Printf("Expanded path: %s", expandedPath) +func ExpandHomeDir(path string) string { + if len(path) == 0 || path[0] != '~' { + return path + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return path + } + + if len(path) == 1 || path[1] == '/' { + return filepath.Join(homeDir, path[1:]) + } + + return filepath.Join(homeDir, path[1:]) +} + // PathExistsInInventory checks if a relative file path exists in any of the inventory directories specified in the -// configuration file. If the file is found in any of the inventory directories, it returns true, otherwise, it returns false. +// inventoryPaths parameter. If the file is found in any of the inventory directories, it returns true, otherwise, it returns false. // // Parameters: // // relPath: A string representing the relative path of the file to search for in the inventory directories. +// inventoryPaths: A []string containing the inventory directory paths to search. // // Returns: // @@ -75,7 +121,8 @@ func CreateDirIfNotExists(path string) error { // Example: // // relFilePath := "templates/exampleTTP.yaml.tmpl" -// exists, err := PathExistsInInventory(relFilePath) +// inventoryPaths := []string{"path/to/inventory1", "path/to/inventory2"} +// exists, err := PathExistsInInventory(relFilePath, inventoryPaths) // // if err != nil { // log.Fatalf("failed to check file existence: %v", err) @@ -87,11 +134,10 @@ func CreateDirIfNotExists(path string) error { // // log.Printf("File %s not found in the inventory directories\n", relFilePath) // } -func PathExistsInInventory(relPath string) (bool, error) { - inventory := viper.GetStringSlice("inventory") - - for _, invPath := range inventory { - absPath := filepath.Join(invPath, relPath) +func PathExistsInInventory(relPath string, inventoryPaths []string) (bool, error) { + for _, invPath := range inventoryPaths { + invPath := ExpandHomeDir(invPath) + absPath := filepath.Join(invPath, "..", relPath) if _, err := os.Stat(absPath); err == nil { return true, nil } @@ -100,12 +146,13 @@ func PathExistsInInventory(relPath string) (bool, error) { return false, nil } -// TemplateExists checks if a template file exists in any of the inventory directories specified in the configuration -// file. If the template file is found, it returns true, otherwise, it returns false. +// TemplateExists checks if a template file exists in any of the inventory directories specified in the inventoryPaths +// parameter. If the template file is found, it returns true, otherwise, it returns false. // // Parameters: // -// templateName: A string representing the name of the template file to search for in the inventory directories. +// templatePath: A string representing the path of the template file to search for in the inventory directories. +// inventoryPaths: A []string containing the inventory directory paths to search. // // Returns: // @@ -114,30 +161,43 @@ func PathExistsInInventory(relPath string) (bool, error) { // // Example: // -// templateName := "exampleTTP" -// exists, err := TemplateExists(templateName) +// templatePath := "bash/bashTTP.yaml.tmpl" +// inventoryPaths := []string{"path/to/inventory1", "path/to/inventory2"} +// exists, err := TemplateExists(templatePath, inventoryPaths) // // if err != nil { // log.Fatalf("failed to check template existence: %v", err) // } // // if exists { -// log.Printf("Template %s found in the inventory directories\n", templateName) +// log.Printf("Template %s found in the inventory directories\n", templatePath) // } else { // -// log.Printf("Template %s not found in the inventory directories\n", templateName) +// log.Printf("Template %s not found in the inventory directories\n", templatePath) // } -func TemplateExists(templateName string) (bool, error) { - templatePath := filepath.Join("templates", templateName+"TTP.yaml.tmpl") - return PathExistsInInventory(templatePath) +func TemplateExists(templatePath string, inventoryPaths []string) (bool, error) { + for _, inventoryPath := range inventoryPaths { + fullPath := filepath.Join(inventoryPath, templatePath) + + // Check if the template exists at the fullPath + if _, err := os.Stat(fullPath); err == nil { + return true, nil + } else if !os.IsNotExist(err) { + return false, err + } + } + + // If the template is not found in any of the paths, return false + return false, nil } -// TTPExists checks if a TTP file exists in any of the inventory directories specified in the configuration file. +// TTPExists checks if a TTP file exists in any of the inventory directories specified in the inventoryPaths parameter. // If the TTP file is found, it returns true, otherwise, it returns false. // // Parameters: // // ttpName: A string representing the name of the TTP file to search for in the inventory directories. +// inventoryPaths: A []string containing the inventory directory paths to search. // // Returns: // @@ -147,7 +207,8 @@ func TemplateExists(templateName string) (bool, error) { // Example: // // ttpName := "exampleTTP" -// exists, err := TTPExists(ttpName) +// inventoryPaths := []string{"path/to/inventory1", "path/to/inventory2"} +// exists, err := TTPExists(ttpName, inventoryPaths) // // if err != nil { // log.Fatalf("failed to check TTP existence: %v", err) @@ -159,7 +220,7 @@ func TemplateExists(templateName string) (bool, error) { // // log.Printf("TTP %s not found in the inventory directories\n", ttpName) // } -func TTPExists(ttpName string) (bool, error) { +func TTPExists(ttpName string, inventoryPaths []string) (bool, error) { ttpPath := filepath.Join("ttps", ttpName+".yaml") - return PathExistsInInventory(ttpPath) + return PathExistsInInventory(ttpPath, inventoryPaths) } diff --git a/pkg/files/file_test.go b/pkg/files/file_test.go index 0ba2b4b5..d908c8d0 100644 --- a/pkg/files/file_test.go +++ b/pkg/files/file_test.go @@ -26,7 +26,6 @@ import ( "testing" "github.com/facebookincubator/ttpforge/pkg/files" - "github.com/spf13/viper" ) func TestCreateDirIfNotExists(t *testing.T) { @@ -76,29 +75,57 @@ func TestCreateDirIfNotExists(t *testing.T) { }) } - os.RemoveAll("testDir") + defer os.RemoveAll("testDir") } -func TestPathExistsInInventory(t *testing.T) { - testDir, err := os.MkdirTemp("", "inventory") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) +func createTestInventory(t *testing.T, dir string) { + t.Helper() + + lateralMovementDir := filepath.Join(dir, "lateral-movement", "ssh") + if err := os.MkdirAll(lateralMovementDir, 0755); err != nil { + t.Fatalf("failed to create lateral movement dir: %v", err) } - defer os.RemoveAll(testDir) - viper.Set("inventory", []string{testDir}) + privEscalationDir := filepath.Join(dir, "privilege-escalation", "credential-theft", "hello-world") + if err := os.MkdirAll(privEscalationDir, 0755); err != nil { + t.Fatalf("failed to create privilege escalation dir: %v", err) + } - filePath := filepath.Join(testDir, "test.txt") - file, err := os.Create(filePath) - if err != nil { - t.Fatalf("failed to create test file: %v", err) + testFiles := []struct { + path string + contents string + }{ + { + path: filepath.Join(lateralMovementDir, "rogue-ssh-key.yaml"), + contents: "test rogue ssh key contents", + }, + { + path: filepath.Join(privEscalationDir, "ttp.yaml"), + contents: "test ttp yaml contents", + }, + } + + for _, file := range testFiles { + f, err := os.Create(file.path) + if err != nil { + t.Fatalf("failed to create test file: %v", err) + } + if _, err := io.WriteString(f, file.contents); err != nil { + t.Fatalf("failed to write to test file: %v", err) + } + f.Close() } +} - if _, err := io.WriteString(file, "test content"); err != nil { - t.Fatalf("failed to write to test file: %v", err) +func TestPathExistsInInventory(t *testing.T) { + testDir, err := os.MkdirTemp("", "inventory") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) } + defer os.RemoveAll(testDir) - file.Close() + createTestInventory(t, testDir) + inventoryPaths := []string{testDir} tests := []struct { name string @@ -107,7 +134,12 @@ func TestPathExistsInInventory(t *testing.T) { }{ { name: "file exists in inventory", - relPath: "test.txt", + relPath: "lateral-movement/ssh/rogue-ssh-key.yaml", + expected: true, + }, + { + name: "file exists in inventory", + relPath: "privilege-escalation/credential-theft/hello-world/ttp.yaml", expected: true, }, { @@ -119,7 +151,7 @@ func TestPathExistsInInventory(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - exists, err := files.PathExistsInInventory(tc.relPath) + exists, err := files.PathExistsInInventory(tc.relPath, inventoryPaths) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -131,27 +163,80 @@ func TestPathExistsInInventory(t *testing.T) { } } +// Borrowed from: https://github.com/l50/goutils/blob/e91b7c4e18e23c53e35d04fa7961a5a14ca8ef39/fileutils_test.go#L294-L340 +func TestExpandHomeDir(t *testing.T) { + homeDir, err := os.UserHomeDir() + if err != nil { + t.Fatalf("failed to get user home directory: %v", err) + } + + testCases := []struct { + name string + input string + expected string + }{ + { + name: "EmptyPath", + input: "", + expected: "", + }, + { + name: "NoTilde", + input: "/path/without/tilde", + expected: "/path/without/tilde", + }, + { + name: "TildeOnly", + input: "~", + expected: homeDir, + }, + { + name: "TildeWithSlash", + input: "~/path/with/slash", + expected: filepath.Join(homeDir, "path/with/slash"), + }, + { + name: "TildeWithoutSlash", + input: "~path/without/slash", + expected: filepath.Join(homeDir, "path/without/slash"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := files.ExpandHomeDir(tc.input) + if actual != tc.expected { + t.Errorf("test failed: ExpandHomeDir(%q) = %q; expected %q", tc.input, actual, tc.expected) + } + }) + } +} + func TestTemplateExists(t *testing.T) { - testDir, err := os.MkdirTemp("", "test") + testDir, err := os.MkdirTemp("", "inventory") if err != nil { t.Fatalf("failed to create temp dir: %v", err) } defer os.RemoveAll(testDir) - templateDir := filepath.Join(testDir, "templates") - if err := os.MkdirAll(templateDir, 0755); err != nil { + createTestInventory(t, testDir) + + bashTemplateDir := filepath.Join(testDir, "templates", "bash") + if err := os.MkdirAll(bashTemplateDir, 0755); err != nil { t.Fatalf("failed to create templates directory: %v", err) } - templatePath := filepath.Join(templateDir, "exampleTTP.yaml.tmpl") - templateFile, err := os.Create(templatePath) + templateFile, err := os.Create(filepath.Join(bashTemplateDir, "bashTTP.yaml.tmpl")) if err != nil { t.Fatalf("failed to create test template: %v", err) } + if _, err := io.WriteString(templateFile, "test template content"); err != nil { t.Fatalf("failed to write to test template: %v", err) } - templateFile.Close() + defer templateFile.Close() + + inventoryPaths := []string{testDir} tests := []struct { name string @@ -160,19 +245,81 @@ func TestTemplateExists(t *testing.T) { }{ { name: "template exists", - template: "exampleTTP", + template: "bash", shouldExist: true, }, { name: "template does not exist", - template: "nonexistentTTP", + template: "nonexistent", + shouldExist: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // exists, err := files.TemplateExists(tc.template, inventoryPaths) + exists, err := files.TemplateExists(filepath.Join("templates", tc.template), inventoryPaths) + + if exists != tc.shouldExist { + t.Fatalf("expected %v, got %v", tc.shouldExist, exists) + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if exists != tc.shouldExist { + t.Fatalf("expected %v, got %v", tc.shouldExist, exists) + } + }) + } +} + +func TestTTPExists(t *testing.T) { + testDir, err := os.MkdirTemp("", "inventory") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(testDir) + + createTestInventory(t, testDir) + inventoryPaths := []string{testDir} + + ttpDir := filepath.Join(testDir, "ttps") + if err := os.MkdirAll(ttpDir, 0755); err != nil { + t.Fatalf("failed to create ttps directory: %v", err) + } + + ttpPath := filepath.Join(ttpDir, "exampleTTP.yaml") + ttpFile, err := os.Create(ttpPath) + if err != nil { + t.Fatalf("failed to create test ttp: %v", err) + } + if _, err := io.WriteString(ttpFile, "test ttp content"); err != nil { + t.Fatalf("failed to write to test ttp: %v", err) + } + ttpFile.Close() + + tests := []struct { + name string + ttpName string + shouldExist bool + }{ + { + name: "TTP exists", + ttpName: "exampleTTP", + shouldExist: true, + }, + { + name: "TTP does not exist", + ttpName: "nonexistentTTP", shouldExist: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - exists, err := files.TemplateExists(tc.template) + exists, err := files.TTPExists(tc.ttpName, inventoryPaths) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/pkg/files/yaml.go b/pkg/files/yaml.go index 9b44b0ab..006e4cf4 100644 --- a/pkg/files/yaml.go +++ b/pkg/files/yaml.go @@ -31,7 +31,7 @@ import ( // ExecuteYAML is the top-level TTP execution function // exported so that we can test it // the returned TTP is also required to assert against in tests -func ExecuteYAML(yamlFile string) (*blocks.TTP, error) { +func ExecuteYAML(yamlFile string, inventoryPaths []string) (*blocks.TTP, error) { ttp, err := blocks.LoadTTP(yamlFile) if err != nil { logging.Logger.Sugar().Errorw("failed to run TTP", zap.Error(err)) @@ -44,7 +44,7 @@ func ExecuteYAML(yamlFile string) (*blocks.TTP, error) { blocks.InventoryPath = []string{} for _, path := range inventory { - exists, err := PathExistsInInventory(path) + exists, err := PathExistsInInventory(path, inventoryPaths) if err != nil { return nil, err } diff --git a/pkg/files/yaml_test.go b/pkg/files/yaml_test.go index badfbbce..1d95a583 100644 --- a/pkg/files/yaml_test.go +++ b/pkg/files/yaml_test.go @@ -75,14 +75,22 @@ echo "you said: $1" "{\"output\":\"you said: wut\"}", }, }, - // Add more test cases here as needed. } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { tempDir, err := os.MkdirTemp("", "e2e-tests") assert.NoError(t, err, "failed to create temporary directory") - defer os.RemoveAll(tempDir) // Clean up the temporary directory + // Clean up the temporary directory + defer os.RemoveAll(tempDir) + + // Create a temporary inventory file and write some content + inventoryPath := filepath.Join(tempDir, "inventory.txt") + err = os.WriteFile(inventoryPath, []byte("sample_inventory_content"), 0644) + assert.NoError(t, err, "failed to write the temporary inventory file") + + // Add the inventoryPath to the inventoryPaths slice + inventoryPaths := []string{inventoryPath} testYAMLPath := filepath.Join(tempDir, tc.testFile) err = os.WriteFile(testYAMLPath, []byte(testVariableExpansionYAML), 0644) @@ -92,7 +100,7 @@ echo "you said: $1" err = os.WriteFile(scriptPath, []byte(testVariableExpansionSH), 0755) assert.NoError(t, err, "failed to write the temporary shell script") - ttp, err := files.ExecuteYAML(testYAMLPath) + ttp, err := files.ExecuteYAML(testYAMLPath, inventoryPaths) assert.NoError(t, err, "execution of the testFile should not cause an error") assert.Equal(t, len(tc.stepOutputs), len(ttp.Steps), "step outputs should have correct length") diff --git a/templates/README.md.tmpl b/templates/bash/README.md.tmpl similarity index 100% rename from templates/README.md.tmpl rename to templates/bash/README.md.tmpl diff --git a/templates/bashTTP.sh.tmpl b/templates/bash/bashTTP.sh.tmpl similarity index 100% rename from templates/bashTTP.sh.tmpl rename to templates/bash/bashTTP.sh.tmpl diff --git a/templates/bashTTP.yaml.tmpl b/templates/bash/bashTTP.yaml.tmpl similarity index 100% rename from templates/bashTTP.yaml.tmpl rename to templates/bash/bashTTP.yaml.tmpl