diff --git a/session_pool.go b/session_pool.go index 43f1aef9..ea4f96af 100644 --- a/session_pool.go +++ b/session_pool.go @@ -12,6 +12,7 @@ import ( "container/list" "fmt" "strconv" + "strings" "sync" "time" @@ -320,6 +321,22 @@ func (pool *SessionPool) CreateTag(tag LabelSchema) (*ResultSet, error) { return rs, nil } +func (pool *SessionPool) ApplyTag(tag LabelSchema) (*ResultSet, error) { + // 1. Check if the tag exists + _, err := pool.DescTag(tag.Name) + fmt.Println("DEBUG: apply tag") + fmt.Println(err) + if err != nil { + // 2. If the tag does not exist, create it + if strings.Contains(strings.ToLower(err.Error()), "not exist") { + return pool.CreateTag(tag) + } + return nil, err + } + + return nil, nil +} + func (pool *SessionPool) DescTag(tagName string) ([]Label, error) { q := fmt.Sprintf("DESC TAG %s;", tagName) rs, err := pool.ExecuteAndCheck(q) diff --git a/session_pool_test.go b/session_pool_test.go index 1fd3e166..b1379eb9 100644 --- a/session_pool_test.go +++ b/session_pool_test.go @@ -359,18 +359,19 @@ func TestSessionPoolSpaceChange(t *testing.T) { } func TestSessionPoolApplySchema(t *testing.T) { - err := prepareSpace("test_space_schema") + spaceName := "test_space_schema" + err := prepareSpace(spaceName) if err != nil { t.Fatal(err) } - defer dropSpace("test_space_schema") + defer dropSpace(spaceName) hostAddress := HostAddress{Host: address, Port: port} config, err := NewSessionPoolConf( "root", "nebula", []HostAddress{hostAddress}, - "test_space_schema") + spaceName) if err != nil { t.Errorf("failed to create session pool config, %s", err.Error()) } @@ -394,7 +395,7 @@ func TestSessionPoolApplySchema(t *testing.T) { for _, space := range spaces { spaceNames = append(spaceNames, space.Name) } - assert.Contains(t, spaceNames, "test_space_schema", "should have test_space_schema") + assert.Contains(t, spaceNames, spaceName) tagSchema := LabelSchema{ Name: "account", @@ -464,6 +465,86 @@ func TestSessionPoolApplySchema(t *testing.T) { assert.Equal(t, "string", labels[0].Type, "field type should be string") } +func TestSessionPoolApplyTag(t *testing.T) { + spaceName := "test_space_apply_tag" + err := prepareSpace(spaceName) + if err != nil { + t.Fatal(err) + } + defer dropSpace(spaceName) + + hostAddress := HostAddress{Host: address, Port: port} + config, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + spaceName) + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + + // allow only one session in the pool so it is easier to test + config.maxSize = 1 + + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + + spaces, err := sessionPool.ShowSpaces() + if err != nil { + t.Fatal(err) + } + assert.LessOrEqual(t, 1, len(spaces), "should have at least 1 space") + var spaceNames []string + for _, space := range spaces { + spaceNames = append(spaceNames, space.Name) + } + assert.Contains(t, spaceNames, spaceName) + + tagSchema := LabelSchema{ + Name: "account", + Fields: []LabelFieldSchema{ + { + Field: "name", + Nullable: false, + }, + // { + // Field: "email", + // Nullable: true, + // }, + // { + // Field: "phone", + // Type: "int64", + // Nullable: true, + // }, + }, + } + _, err = sessionPool.ApplyTag(tagSchema) + if err != nil { + t.Fatal(err) + } + tags, err := sessionPool.ShowTags() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, len(tags), "should have 1 tags") + assert.Equal(t, "account", tags[0].Name, "tag name should be account") + labels, err := sessionPool.DescTag("account") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, len(labels), "should have 1 labels") + assert.Equal(t, "name", labels[0].Field, "field name should be name") + assert.Equal(t, "string", labels[0].Type, "field type should be string") + // assert.Equal(t, "email", labels[1].Field, "field name should be email") + // assert.Equal(t, "string", labels[1].Type, "field type should be string") + // assert.Equal(t, "phone", labels[2].Field, "field name should be phone") + // assert.Equal(t, "int64", labels[2].Type, "field type should be int64") +} + func TestIdleSessionCleaner(t *testing.T) { err := prepareSpace("client_test") if err != nil {