diff --git a/input/input.go b/input/input.go index 12a382b..26c3247 100644 --- a/input/input.go +++ b/input/input.go @@ -244,20 +244,97 @@ func (inp *Input) ValidateExecutionModes() error { return nil } -// LoadInput loads the input file -func LoadInput(filename string) (*Input, error) { +// checkUnmappedModules compares the YAML map to the ModuleParams fields +func checkUnmappedModules(inp *Input, yamlMap map[string]interface{}) []string { + var unmappedModules []string + + scenarios, _ := yamlMap["scenarios"].([]interface{}) + // if !ok { + // return unmappedModules + // } + + for _, scenario := range scenarios { + scenarioMap, _ := scenario.(map[string]interface{}) + // if !ok { + // continue + // } + + parameters, _ := scenarioMap["parameters"].(map[string]interface{}) + // if !ok { + // continue + // } + + modules, _ := parameters["modules"].(map[string]interface{}) + // if !ok { + // continue + // } + + v := reflect.ValueOf(inp.Scenarios[0].Parameters.Modules) + t := v.Type() + + for key := range modules { + if key == "order" { + continue // Skip the "order" field as it's handled differently + } + + found := false + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + yamlTag := field.Tag.Get("yaml") + if yamlTag == key || (yamlTag == "" && strings.ToLower(field.Name) == key) { + found = true + break + } + } + if !found { + unmappedModules = append(unmappedModules, key) + } + } + } + + return unmappedModules +} +// LoadInput loads the input file and checks for unmapped modules +func LoadInput(filename string) (*Input, error) { yamlFile, err := os.ReadFile(filename) if err != nil { return nil, err } - inp := &Input{} - err = yaml.Unmarshal(yamlFile, inp) + // Unmarshal into a map to get all fields from YAML + var yamlMap map[string]interface{} + err = yaml.Unmarshal(yamlFile, &yamlMap) if err != nil { return nil, err } + // Unmarshal into the Input struct + inp := &Input{} + _ = yaml.Unmarshal(yamlFile, inp) + // No checking the errors here since the unmarshalling above will catch it + // if _ != nil { + // return nil, err + // } + + if utils.IsHaddock3(inp.General.HaddockDir) { + + // Check for unmapped modules + unmappedModules := checkUnmappedModules(inp, yamlMap) + if len(unmappedModules) > 0 { + + _s := "Module" + _a := "was" + if len(unmappedModules) > 1 { + _s = "Modules" + _a = "were" + } + unknownModules := strings.Join(unmappedModules, ", ") + return nil, errors.New(_s + " `" + unknownModules + "` " + _a + " not found in the HADDOCK installation or supported in this version.") + + } + } + return inp, nil } diff --git a/input/input_test.go b/input/input_test.go index 5bd9054..fc562d0 100644 --- a/input/input_test.go +++ b/input/input_test.go @@ -96,6 +96,75 @@ scenarios: param1: value1 `) + // Write a file with unknown modules + // Fake a haddock3 directory structure with a defaults.yaml file + temp_dir := "_testloadinput_haddock3" + _ = os.MkdirAll(filepath.Join(temp_dir, "src/haddock/modules"), 0755) + _ = os.WriteFile(filepath.Join(temp_dir, "src/haddock/modules/defaults.yaml"), []byte(""), 0755) + // Create a known module + // _ = os.MkdirAll(filepath.Join(temp_dir, "src/haddock/modules/rigidbody"), 0755) + defer os.RemoveAll(temp_dir) + + d4 := []byte(`general: + executable: /home/rodrigo/repos/haddock-runner/haddock3.sh + max_concurrent: 999 + haddock_dir: _testloadinput_haddock3 + receptor_suffix: _r_u + ligand_suffix: _l_u + input_list: example/input_list.txt + work_dir: bm-goes-here + +scenarios: + - name: true-interface + parameters: + run_cns: + noecv: false + restraints: + ambig: ti + custom_toppar: + topology: _ligand.top + general: + ncores: 1 + modules: + order: [unknown, caprieval.1, caprieval.2] + unknown: + param1: value1 + caprieval.1: + param1: value1 + caprieval.2: + param1: value1 +`) + + d5 := []byte(`general: + executable: /home/rodrigo/repos/haddock-runner/haddock3.sh + max_concurrent: 999 + haddock_dir: _testloadinput_haddock3 + receptor_suffix: _r_u + ligand_suffix: _l_u + input_list: example/input_list.txt + work_dir: bm-goes-here + +scenarios: + - name: true-interface + parameters: + run_cns: + noecv: false + restraints: + ambig: ti + custom_toppar: + topology: _ligand.top + general: + ncores: 1 + modules: + order: [rigidbody, unknown, unknownalso] + rigidbody: + param1: value1 + unknown: + param1: value1 + unknownalso: + param1: value1 +`) + err := os.WriteFile("test-input.yaml", d1, 0644) if err != nil { t.Errorf("Failed to write input file: %s", err) @@ -117,6 +186,20 @@ scenarios: defer os.Remove("test-input-repeated.yaml") + err = os.WriteFile("test-input-unknown.yaml", d4, 0644) + if err != nil { + t.Errorf("Failed to write input file: %s", err) + } + + defer os.Remove("test-input-unknown.yaml") + + err = os.WriteFile("test-input-unknown-twice.yaml", d5, 0644) + if err != nil { + t.Errorf("Failed to write input file: %s", err) + } + + defer os.Remove("test-input-unknown-twice.yaml") + type args struct { filename string } @@ -235,6 +318,30 @@ scenarios: }, wantErr: false, }, + { + name: "unknown module", + args: args{ + filename: "test-input-unknown.yaml", + }, + want: nil, + wantErr: true, + }, + { + name: "unknown modules", + args: args{ + filename: "test-input-unknown-twice.yaml", + }, + want: nil, + wantErr: true, + }, + { + name: "wrong type", + args: args{ + filename: "test-input-wrong-type.yaml", + }, + want: nil, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {