From 310007797c893aaf2b19f488c870fa064580dcb3 Mon Sep 17 00:00:00 2001
From: dave <logindaveye@gmail.com>
Date: Tue, 3 Aug 2021 19:11:06 +0800
Subject: [PATCH 1/8] util/topsql/reporter: migrate test-infra to testify
 (#26721)

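The conversion follows the usual gocheck-to-testify mapping: suite methods
become plain TestXxx(t *testing.T) functions, and c.Assert checks become
require assertions. Below is a minimal illustrative sketch of the pattern;
the test name and values are hypothetical and not part of this change.

    package reporter

    import (
        "testing"

        "github.com/stretchr/testify/require"
    )

    func TestExampleMapping(t *testing.T) {
        // c.Assert(err, IsNil)        ->  require.NoError(t, err)
        // c.Assert(v, HasLen, n)      ->  require.Len(t, v, n)
        // c.Assert(got, Equals, want) ->  require.Equal(t, want, got)
        got := []uint32{1, 2, 3}
        require.Len(t, got, 3)
        require.Equal(t, uint32(2), got[1])
    }
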
---
 util/topsql/reporter/main_test.go     |  26 +++
 util/topsql/reporter/reporter_test.go | 219 ++++++++++++--------------
 2 files changed, 128 insertions(+), 117 deletions(-)
 create mode 100644 util/topsql/reporter/main_test.go

diff --git a/util/topsql/reporter/main_test.go b/util/topsql/reporter/main_test.go
new file mode 100644
index 0000000000000..84505af739be9
--- /dev/null
+++ b/util/topsql/reporter/main_test.go
@@ -0,0 +1,26 @@
+// Copyright 2021 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reporter
+
+import (
+	"testing"
+
+	"github.com/pingcap/tidb/util/testbridge"
+	"go.uber.org/goleak"
+)
+
+func TestMain(m *testing.M) {
+	testbridge.WorkaroundGoCheckFlags()
+	goleak.VerifyTestMain(m)
+}
diff --git a/util/topsql/reporter/reporter_test.go b/util/topsql/reporter/reporter_test.go
index d4876e4863e16..34d454aca3280 100644
--- a/util/topsql/reporter/reporter_test.go
+++ b/util/topsql/reporter/reporter_test.go
@@ -20,29 +20,17 @@ import (
 	"testing"
 	"time"
 
-	. "github.com/pingcap/check"
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/util/topsql/reporter/mock"
 	"github.com/pingcap/tidb/util/topsql/tracecpu"
 	"github.com/pingcap/tipb/go-tipb"
+	"github.com/stretchr/testify/require"
 )
 
 const (
 	maxSQLNum = 5000
 )
 
-func TestT(t *testing.T) {
-	TestingT(t)
-}
-
-var _ = SerialSuites(&testTopSQLReporter{})
-
-type testTopSQLReporter struct{}
-
-func (s *testTopSQLReporter) SetUpSuite(c *C) {}
-
-func (s *testTopSQLReporter) SetUpTest(c *C) {}
-
 func populateCache(tsr *RemoteTopSQLReporter, begin, end int, timestamp uint64) {
 	// register normalized sql
 	for i := begin; i < end; i++ {
@@ -90,9 +78,9 @@ func initializeCache(maxStatementsNum, interval int, addr string) *RemoteTopSQLR
 	return ts
 }
 
-func (s *testTopSQLReporter) TestCollectAndSendBatch(c *C) {
+func TestCollectAndSendBatch(t *testing.T) {
 	agentServer, err := mock.StartMockAgentServer()
-	c.Assert(err, IsNil)
+	require.NoError(t, err)
 	defer agentServer.Stop()
 
 	tsr := setupRemoteTopSQLReporter(maxSQLNum, 1, agentServer.Address())
@@ -100,8 +88,7 @@ func (s *testTopSQLReporter) TestCollectAndSendBatch(c *C) {
 	populateCache(tsr, 0, maxSQLNum, 1)
 
 	agentServer.WaitCollectCnt(1, time.Second*5)
-
-	c.Assert(agentServer.GetLatestRecords(), HasLen, maxSQLNum)
+	require.Len(t, agentServer.GetLatestRecords(), maxSQLNum)
 
 	// check for equality of server received batch and the original data
 	records := agentServer.GetLatestRecords()
@@ -110,29 +97,29 @@ func (s *testTopSQLReporter) TestCollectAndSendBatch(c *C) {
 		prefix := "sqlDigest"
 		if strings.HasPrefix(string(req.SqlDigest), prefix) {
 			n, err := strconv.Atoi(string(req.SqlDigest)[len(prefix):])
-			c.Assert(err, IsNil)
+			require.NoError(t, err)
 			id = n
 		}
-		c.Assert(req.RecordListCpuTimeMs, HasLen, 1)
+		require.Len(t, req.RecordListCpuTimeMs, 1)
 		for i := range req.RecordListCpuTimeMs {
-			c.Assert(req.RecordListCpuTimeMs[i], Equals, uint32(id))
+			require.Equal(t, uint32(id), req.RecordListCpuTimeMs[i])
 		}
-		c.Assert(req.RecordListTimestampSec, HasLen, 1)
+		require.Len(t, req.RecordListTimestampSec, 1)
 		for i := range req.RecordListTimestampSec {
-			c.Assert(req.RecordListTimestampSec[i], Equals, uint64(1))
+			require.Equal(t, uint64(1), req.RecordListTimestampSec[i])
 		}
 		sqlMeta, exist := agentServer.GetSQLMetaByDigestBlocking(req.SqlDigest, time.Second)
-		c.Assert(exist, IsTrue)
-		c.Assert(sqlMeta.NormalizedSql, Equals, "sqlNormalized"+strconv.Itoa(id))
+		require.True(t, exist)
+		require.Equal(t, "sqlNormalized"+strconv.Itoa(id), sqlMeta.NormalizedSql)
 		normalizedPlan, exist := agentServer.GetPlanMetaByDigestBlocking(req.PlanDigest, time.Second)
-		c.Assert(exist, IsTrue)
-		c.Assert(normalizedPlan, Equals, "planNormalized"+strconv.Itoa(id))
+		require.True(t, exist)
+		require.Equal(t, "planNormalized"+strconv.Itoa(id), normalizedPlan)
 	}
 }
 
-func (s *testTopSQLReporter) TestCollectAndEvicted(c *C) {
+func TestCollectAndEvicted(t *testing.T) {
 	agentServer, err := mock.StartMockAgentServer()
-	c.Assert(err, IsNil)
+	require.NoError(t, err)
 	defer agentServer.Stop()
 
 	tsr := setupRemoteTopSQLReporter(maxSQLNum, 1, agentServer.Address())
@@ -143,38 +130,38 @@ func (s *testTopSQLReporter) TestCollectAndEvicted(c *C) {
 
 	// check for equality of server received batch and the original data
 	records := agentServer.GetLatestRecords()
-	c.Assert(records, HasLen, maxSQLNum+1)
+	require.Len(t, records, maxSQLNum+1)
 	for _, req := range records {
 		id := 0
 		prefix := "sqlDigest"
 		if strings.HasPrefix(string(req.SqlDigest), prefix) {
 			n, err := strconv.Atoi(string(req.SqlDigest)[len(prefix):])
-			c.Assert(err, IsNil)
+			require.NoError(t, err)
 			id = n
 		}
-		c.Assert(req.RecordListTimestampSec, HasLen, 1)
-		c.Assert(req.RecordListTimestampSec[0], Equals, uint64(2))
-		c.Assert(req.RecordListCpuTimeMs, HasLen, 1)
+		require.Len(t, req.RecordListTimestampSec, 1)
+		require.Equal(t, uint64(2), req.RecordListTimestampSec[0])
+		require.Len(t, req.RecordListCpuTimeMs, 1)
 		if id == 0 {
 			// test for others
-			c.Assert(req.SqlDigest, IsNil)
-			c.Assert(req.PlanDigest, IsNil)
+			require.Nil(t, req.SqlDigest)
+			require.Nil(t, req.PlanDigest)
 			// 12502500 is the sum of all evicted item's cpu time. 1 + 2 + 3 + ... + 5000 = (1 + 5000) * 2500 = 12502500
-			c.Assert(int(req.RecordListCpuTimeMs[0]), Equals, 12502500)
+			require.Equal(t, 12502500, int(req.RecordListCpuTimeMs[0]))
 			continue
 		}
-		c.Assert(id > maxSQLNum, IsTrue)
-		c.Assert(req.RecordListCpuTimeMs[0], Equals, uint32(id))
+		require.Greater(t, id, maxSQLNum)
+		require.Equal(t, uint32(id), req.RecordListCpuTimeMs[0])
 		sqlMeta, exist := agentServer.GetSQLMetaByDigestBlocking(req.SqlDigest, time.Second)
-		c.Assert(exist, IsTrue)
-		c.Assert(sqlMeta.NormalizedSql, Equals, "sqlNormalized"+strconv.Itoa(id))
+		require.True(t, exist)
+		require.Equal(t, "sqlNormalized"+strconv.Itoa(id), sqlMeta.NormalizedSql)
 		normalizedPlan, exist := agentServer.GetPlanMetaByDigestBlocking(req.PlanDigest, time.Second)
-		c.Assert(exist, IsTrue)
-		c.Assert(normalizedPlan, Equals, "planNormalized"+strconv.Itoa(id))
+		require.True(t, exist)
+		require.Equal(t, "planNormalized"+strconv.Itoa(id), normalizedPlan)
 	}
 }
 
-func (s *testTopSQLReporter) newSQLCPUTimeRecord(tsr *RemoteTopSQLReporter, sqlID int, cpuTimeMs uint32) tracecpu.SQLCPUTimeRecord {
+func newSQLCPUTimeRecord(tsr *RemoteTopSQLReporter, sqlID int, cpuTimeMs uint32) tracecpu.SQLCPUTimeRecord {
 	key := []byte("sqlDigest" + strconv.Itoa(sqlID))
 	value := "sqlNormalized" + strconv.Itoa(sqlID)
 	tsr.RegisterSQL(key, value, sqlID%2 == 0)
@@ -190,56 +177,56 @@ func (s *testTopSQLReporter) newSQLCPUTimeRecord(tsr *RemoteTopSQLReporter, sqlI
 	}
 }
 
-func (s *testTopSQLReporter) collectAndWait(tsr *RemoteTopSQLReporter, timestamp uint64, records []tracecpu.SQLCPUTimeRecord) {
+func collectAndWait(tsr *RemoteTopSQLReporter, timestamp uint64, records []tracecpu.SQLCPUTimeRecord) {
 	tsr.Collect(timestamp, records)
 	time.Sleep(time.Millisecond * 100)
 }
 
-func (s *testTopSQLReporter) TestCollectAndTopN(c *C) {
+func TestCollectAndTopN(t *testing.T) {
 	agentServer, err := mock.StartMockAgentServer()
-	c.Assert(err, IsNil)
+	require.NoError(t, err)
 	defer agentServer.Stop()
 
 	tsr := setupRemoteTopSQLReporter(2, 1, agentServer.Address())
 	defer tsr.Close()
 
 	records := []tracecpu.SQLCPUTimeRecord{
-		s.newSQLCPUTimeRecord(tsr, 1, 1),
-		s.newSQLCPUTimeRecord(tsr, 2, 2),
+		newSQLCPUTimeRecord(tsr, 1, 1),
+		newSQLCPUTimeRecord(tsr, 2, 2),
 	}
-	s.collectAndWait(tsr, 1, records)
+	collectAndWait(tsr, 1, records)
 
 	records = []tracecpu.SQLCPUTimeRecord{
-		s.newSQLCPUTimeRecord(tsr, 3, 3),
-		s.newSQLCPUTimeRecord(tsr, 1, 1),
+		newSQLCPUTimeRecord(tsr, 3, 3),
+		newSQLCPUTimeRecord(tsr, 1, 1),
 	}
-	s.collectAndWait(tsr, 2, records)
+	collectAndWait(tsr, 2, records)
 
 	records = []tracecpu.SQLCPUTimeRecord{
-		s.newSQLCPUTimeRecord(tsr, 4, 1),
-		s.newSQLCPUTimeRecord(tsr, 1, 1),
+		newSQLCPUTimeRecord(tsr, 4, 1),
+		newSQLCPUTimeRecord(tsr, 1, 1),
 	}
-	s.collectAndWait(tsr, 3, records)
+	collectAndWait(tsr, 3, records)
 
 	records = []tracecpu.SQLCPUTimeRecord{
-		s.newSQLCPUTimeRecord(tsr, 5, 1),
-		s.newSQLCPUTimeRecord(tsr, 1, 1),
+		newSQLCPUTimeRecord(tsr, 5, 1),
+		newSQLCPUTimeRecord(tsr, 1, 1),
 	}
-	s.collectAndWait(tsr, 4, records)
+	collectAndWait(tsr, 4, records)
 
 	// Test for time jump back.
 	records = []tracecpu.SQLCPUTimeRecord{
-		s.newSQLCPUTimeRecord(tsr, 6, 1),
-		s.newSQLCPUTimeRecord(tsr, 1, 1),
+		newSQLCPUTimeRecord(tsr, 6, 1),
+		newSQLCPUTimeRecord(tsr, 1, 1),
 	}
-	s.collectAndWait(tsr, 0, records)
+	collectAndWait(tsr, 0, records)
 
 	// Wait agent server collect finish.
 	agentServer.WaitCollectCnt(1, time.Second*10)
 
 	// check for equality of server received batch and the original data
 	results := agentServer.GetLatestRecords()
-	c.Assert(results, HasLen, 3)
+	require.Len(t, results, 3)
 	sort.Slice(results, func(i, j int) bool {
 		return string(results[i].SqlDigest) < string(results[j].SqlDigest)
 	})
@@ -250,17 +237,17 @@ func (s *testTopSQLReporter) TestCollectAndTopN(c *C) {
 		}
 		return int(total)
 	}
-	c.Assert(results[0].SqlDigest, IsNil)
-	c.Assert(getTotalCPUTime(results[0]), Equals, 5)
-	c.Assert(results[0].RecordListTimestampSec, DeepEquals, []uint64{0, 1, 3, 4})
-	c.Assert(results[0].RecordListCpuTimeMs, DeepEquals, []uint32{1, 2, 1, 1})
-	c.Assert(results[1].SqlDigest, DeepEquals, []byte("sqlDigest1"))
-	c.Assert(getTotalCPUTime(results[1]), Equals, 5)
-	c.Assert(results[2].SqlDigest, DeepEquals, []byte("sqlDigest3"))
-	c.Assert(getTotalCPUTime(results[2]), Equals, 3)
+	require.Nil(t, results[0].SqlDigest)
+	require.Equal(t, 5, getTotalCPUTime(results[0]))
+	require.Equal(t, []uint64{0, 1, 3, 4}, results[0].RecordListTimestampSec)
+	require.Equal(t, []uint32{1, 2, 1, 1}, results[0].RecordListCpuTimeMs)
+	require.Equal(t, []byte("sqlDigest1"), results[1].SqlDigest)
+	require.Equal(t, 5, getTotalCPUTime(results[1]))
+	require.Equal(t, []byte("sqlDigest3"), results[2].SqlDigest)
+	require.Equal(t, 3, getTotalCPUTime(results[2]))
 }
 
-func (s *testTopSQLReporter) TestCollectCapacity(c *C) {
+func TestCollectCapacity(t *testing.T) {
 	tsr := setupRemoteTopSQLReporter(maxSQLNum, 60, "")
 	defer tsr.Close()
 
@@ -292,41 +279,41 @@ func (s *testTopSQLReporter) TestCollectCapacity(c *C) {
 
 	variable.TopSQLVariable.MaxCollect.Store(10000)
 	registerSQL(5000)
-	c.Assert(tsr.sqlMapLength.Load(), Equals, int64(5000))
+	require.Equal(t, int64(5000), tsr.sqlMapLength.Load())
 	registerPlan(1000)
-	c.Assert(tsr.planMapLength.Load(), Equals, int64(1000))
+	require.Equal(t, int64(1000), tsr.planMapLength.Load())
 
 	registerSQL(20000)
-	c.Assert(tsr.sqlMapLength.Load(), Equals, int64(10000))
+	require.Equal(t, int64(10000), tsr.sqlMapLength.Load())
 	registerPlan(20000)
-	c.Assert(tsr.planMapLength.Load(), Equals, int64(10000))
+	require.Equal(t, int64(10000), tsr.planMapLength.Load())
 
 	variable.TopSQLVariable.MaxCollect.Store(20000)
 	registerSQL(50000)
-	c.Assert(tsr.sqlMapLength.Load(), Equals, int64(20000))
+	require.Equal(t, int64(20000), tsr.sqlMapLength.Load())
 	registerPlan(50000)
-	c.Assert(tsr.planMapLength.Load(), Equals, int64(20000))
+	require.Equal(t, int64(20000), tsr.planMapLength.Load())
 
 	variable.TopSQLVariable.MaxStatementCount.Store(5000)
 	collectedData := make(map[string]*dataPoints)
 	tsr.doCollect(collectedData, 1, genRecord(20000))
-	c.Assert(len(collectedData), Equals, 5001)
-	c.Assert(tsr.sqlMapLength.Load(), Equals, int64(5000))
-	c.Assert(tsr.planMapLength.Load(), Equals, int64(5000))
+	require.Equal(t, 5001, len(collectedData))
+	require.Equal(t, int64(5000), tsr.sqlMapLength.Load())
+	require.Equal(t, int64(5000), tsr.planMapLength.Load())
 }
 
-func (s *testTopSQLReporter) TestCollectOthers(c *C) {
+func TestCollectOthers(t *testing.T) {
 	collectTarget := make(map[string]*dataPoints)
 	addEvictedCPUTime(collectTarget, 1, 10)
 	addEvictedCPUTime(collectTarget, 2, 20)
 	addEvictedCPUTime(collectTarget, 3, 30)
 	others := collectTarget[keyOthers]
-	c.Assert(others.CPUTimeMsTotal, Equals, uint64(60))
-	c.Assert(others.TimestampList, DeepEquals, []uint64{1, 2, 3})
-	c.Assert(others.CPUTimeMsList, DeepEquals, []uint32{10, 20, 30})
+	require.Equal(t, uint64(60), others.CPUTimeMsTotal)
+	require.Equal(t, []uint64{1, 2, 3}, others.TimestampList)
+	require.Equal(t, []uint32{10, 20, 30}, others.CPUTimeMsList)
 
 	others = addEvictedIntoSortedDataPoints(nil, others)
-	c.Assert(others.CPUTimeMsTotal, Equals, uint64(60))
+	require.Equal(t, uint64(60), others.CPUTimeMsTotal)
 
 	// test for time jump backward.
 	evict := &dataPoints{}
@@ -334,25 +321,25 @@ func (s *testTopSQLReporter) TestCollectOthers(c *C) {
 	evict.CPUTimeMsList = []uint32{30, 20, 40}
 	evict.CPUTimeMsTotal = 90
 	others = addEvictedIntoSortedDataPoints(others, evict)
-	c.Assert(others.CPUTimeMsTotal, Equals, uint64(150))
-	c.Assert(others.TimestampList, DeepEquals, []uint64{1, 2, 3, 4})
-	c.Assert(others.CPUTimeMsList, DeepEquals, []uint32{10, 40, 60, 40})
+	require.Equal(t, uint64(150), others.CPUTimeMsTotal)
+	require.Equal(t, []uint64{1, 2, 3, 4}, others.TimestampList)
+	require.Equal(t, []uint32{10, 40, 60, 40}, others.CPUTimeMsList)
 }
 
-func (s *testTopSQLReporter) TestDataPoints(c *C) {
+func TestDataPoints(t *testing.T) {
 	// test for dataPoints invalid.
 	d := &dataPoints{}
 	d.TimestampList = []uint64{1}
 	d.CPUTimeMsList = []uint32{10, 30}
-	c.Assert(d.isInvalid(), Equals, true)
+	require.True(t, d.isInvalid())
 
 	// test for dataPoints sort.
 	d = &dataPoints{}
 	d.TimestampList = []uint64{1, 2, 5, 6, 3, 4}
 	d.CPUTimeMsList = []uint32{10, 20, 50, 60, 30, 40}
 	sort.Sort(d)
-	c.Assert(d.TimestampList, DeepEquals, []uint64{1, 2, 3, 4, 5, 6})
-	c.Assert(d.CPUTimeMsList, DeepEquals, []uint32{10, 20, 30, 40, 50, 60})
+	require.Equal(t, []uint64{1, 2, 3, 4, 5, 6}, d.TimestampList)
+	require.Equal(t, []uint32{10, 20, 30, 40, 50, 60}, d.CPUTimeMsList)
 
 	// test for dataPoints merge.
 	d = &dataPoints{}
@@ -362,17 +349,17 @@ func (s *testTopSQLReporter) TestDataPoints(c *C) {
 	evict.CPUTimeMsList = []uint32{10, 30}
 	evict.CPUTimeMsTotal = 40
 	addEvictedIntoSortedDataPoints(d, evict)
-	c.Assert(d.CPUTimeMsTotal, Equals, uint64(40))
-	c.Assert(d.TimestampList, DeepEquals, []uint64{1, 3})
-	c.Assert(d.CPUTimeMsList, DeepEquals, []uint32{10, 30})
+	require.Equal(t, uint64(40), d.CPUTimeMsTotal)
+	require.Equal(t, []uint64{1, 3}, d.TimestampList)
+	require.Equal(t, []uint32{10, 30}, d.CPUTimeMsList)
 
 	evict.TimestampList = []uint64{1, 2, 3, 4, 5}
 	evict.CPUTimeMsList = []uint32{10, 20, 30, 40, 50}
 	evict.CPUTimeMsTotal = 150
 	addEvictedIntoSortedDataPoints(d, evict)
-	c.Assert(d.CPUTimeMsTotal, Equals, uint64(190))
-	c.Assert(d.TimestampList, DeepEquals, []uint64{1, 2, 3, 4, 5})
-	c.Assert(d.CPUTimeMsList, DeepEquals, []uint32{20, 20, 60, 40, 50})
+	require.Equal(t, uint64(190), d.CPUTimeMsTotal)
+	require.Equal(t, []uint64{1, 2, 3, 4, 5}, d.TimestampList)
+	require.Equal(t, []uint32{20, 20, 60, 40, 50}, d.CPUTimeMsList)
 
 	// test for time jump backward.
 	d = &dataPoints{}
@@ -381,56 +368,54 @@ func (s *testTopSQLReporter) TestDataPoints(c *C) {
 	evict.CPUTimeMsList = []uint32{30, 20}
 	evict.CPUTimeMsTotal = 50
 	addEvictedIntoSortedDataPoints(d, evict)
-	c.Assert(d.CPUTimeMsTotal, Equals, uint64(50))
-	c.Assert(d.TimestampList, DeepEquals, []uint64{2, 3})
-	c.Assert(d.CPUTimeMsList, DeepEquals, []uint32{20, 30})
+	require.Equal(t, uint64(50), d.CPUTimeMsTotal)
+	require.Equal(t, []uint64{2, 3}, d.TimestampList)
+	require.Equal(t, []uint32{20, 30}, d.CPUTimeMsList)
 
 	// test for merge invalid dataPoints
 	d = &dataPoints{}
 	evict = &dataPoints{}
 	evict.TimestampList = []uint64{1}
 	evict.CPUTimeMsList = []uint32{10, 30}
-	c.Assert(evict.isInvalid(), Equals, true)
+	require.True(t, evict.isInvalid())
 	addEvictedIntoSortedDataPoints(d, evict)
-	c.Assert(d.isInvalid(), Equals, false)
-	c.Assert(d.CPUTimeMsList, IsNil)
-	c.Assert(d.TimestampList, IsNil)
+	require.False(t, d.isInvalid())
+	require.Nil(t, d.CPUTimeMsList)
+	require.Nil(t, d.TimestampList)
 }
 
-func (s *testTopSQLReporter) TestCollectInternal(c *C) {
+func TestCollectInternal(t *testing.T) {
 	agentServer, err := mock.StartMockAgentServer()
-	c.Assert(err, IsNil)
+	require.NoError(t, err)
 	defer agentServer.Stop()
 
 	tsr := setupRemoteTopSQLReporter(3000, 1, agentServer.Address())
 	defer tsr.Close()
 
 	records := []tracecpu.SQLCPUTimeRecord{
-		s.newSQLCPUTimeRecord(tsr, 1, 1),
-		s.newSQLCPUTimeRecord(tsr, 2, 2),
+		newSQLCPUTimeRecord(tsr, 1, 1),
+		newSQLCPUTimeRecord(tsr, 2, 2),
 	}
-	s.collectAndWait(tsr, 1, records)
+	collectAndWait(tsr, 1, records)
 
 	// Wait agent server collect finish.
 	agentServer.WaitCollectCnt(1, time.Second*10)
 
 	// check for equality of server received batch and the original data
 	results := agentServer.GetLatestRecords()
-	c.Assert(results, HasLen, 2)
+	require.Len(t, results, 2)
 	for _, req := range results {
 		id := 0
 		prefix := "sqlDigest"
 		if strings.HasPrefix(string(req.SqlDigest), prefix) {
 			n, err := strconv.Atoi(string(req.SqlDigest)[len(prefix):])
-			c.Assert(err, IsNil)
+			require.NoError(t, err)
 			id = n
 		}
-		if id == 0 {
-			c.Fatalf("the id should not be 0")
-		}
+		require.NotEqualf(t, 0, id, "the id should not be 0")
 		sqlMeta, exist := agentServer.GetSQLMetaByDigestBlocking(req.SqlDigest, time.Second)
-		c.Assert(exist, IsTrue)
-		c.Assert(sqlMeta.IsInternalSql, Equals, id%2 == 0)
+		require.True(t, exist)
+		require.Equal(t, id%2 == 0, sqlMeta.IsInternalSql)
 	}
 }
 

From dd5546dfd0ad8376b3c7188da7c19d8c3b1818ab Mon Sep 17 00:00:00 2001
From: Yuanjia Zhang <zhangyuanjia@pingcap.com>
Date: Tue, 3 Aug 2021 19:21:06 +0800
Subject: [PATCH 2/8] Revert "planner: fix the issue that UnionScan returns
 wrong results in dynamic mode" (#26853)

---
 cmd/explaintest/r/generated_columns.result | 35 ++++------------------
 cmd/explaintest/r/select.result            | 22 +++++---------
 planner/core/integration_test.go           | 14 ---------
 sessionctx/variable/session.go             |  5 ----
 4 files changed, 13 insertions(+), 63 deletions(-)

diff --git a/cmd/explaintest/r/generated_columns.result b/cmd/explaintest/r/generated_columns.result
index d7f120eb28f3f..761dfc6053354 100644
--- a/cmd/explaintest/r/generated_columns.result
+++ b/cmd/explaintest/r/generated_columns.result
@@ -105,37 +105,14 @@ PARTITION p5 VALUES LESS THAN (6),
 PARTITION max VALUES LESS THAN MAXVALUE);
 EXPLAIN format = 'brief' SELECT * FROM sgc3 WHERE a <= 1;
 id	estRows	task	access object	operator info
-PartitionUnion	6646.67	root		
-├─TableReader	3323.33	root		data:Selection
-│ └─Selection	3323.33	cop[tikv]		le(test.sgc3.a, 1)
-│   └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p0	keep order:false, stats:pseudo
-└─TableReader	3323.33	root		data:Selection
-  └─Selection	3323.33	cop[tikv]		le(test.sgc3.a, 1)
-    └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p1	keep order:false, stats:pseudo
+TableReader	3323.33	root	partition:p0,p1	data:Selection
+└─Selection	3323.33	cop[tikv]		le(test.sgc3.a, 1)
+  └─TableFullScan	10000.00	cop[tikv]	table:sgc3	keep order:false, stats:pseudo
 EXPLAIN format = 'brief' SELECT * FROM sgc3 WHERE a < 7;
 id	estRows	task	access object	operator info
-PartitionUnion	23263.33	root		
-├─TableReader	3323.33	root		data:Selection
-│ └─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
-│   └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p0	keep order:false, stats:pseudo
-├─TableReader	3323.33	root		data:Selection
-│ └─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
-│   └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p1	keep order:false, stats:pseudo
-├─TableReader	3323.33	root		data:Selection
-│ └─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
-│   └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p2	keep order:false, stats:pseudo
-├─TableReader	3323.33	root		data:Selection
-│ └─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
-│   └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p3	keep order:false, stats:pseudo
-├─TableReader	3323.33	root		data:Selection
-│ └─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
-│   └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p4	keep order:false, stats:pseudo
-├─TableReader	3323.33	root		data:Selection
-│ └─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
-│   └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:p5	keep order:false, stats:pseudo
-└─TableReader	3323.33	root		data:Selection
-  └─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
-    └─TableFullScan	10000.00	cop[tikv]	table:sgc3, partition:max	keep order:false, stats:pseudo
+TableReader	3323.33	root	partition:all	data:Selection
+└─Selection	3323.33	cop[tikv]		lt(test.sgc3.a, 7)
+  └─TableFullScan	10000.00	cop[tikv]	table:sgc3	keep order:false, stats:pseudo
 DROP TABLE IF EXISTS t1;
 CREATE TABLE t1(a INT, b INT AS (a+1) VIRTUAL, c INT AS (b+1) VIRTUAL, d INT AS (c+1) VIRTUAL, KEY(b), INDEX IDX(c, d));
 INSERT INTO t1 (a) VALUES (0);
diff --git a/cmd/explaintest/r/select.result b/cmd/explaintest/r/select.result
index c69a29bd4a963..41369bffcfbbf 100644
--- a/cmd/explaintest/r/select.result
+++ b/cmd/explaintest/r/select.result
@@ -359,25 +359,17 @@ insert into th values (0,0),(1,1),(2,2),(3,3),(4,4),(5,5),(6,6),(7,7),(8,8);
 insert into th values (-1,-1),(-2,-2),(-3,-3),(-4,-4),(-5,-5),(-6,-6),(-7,-7),(-8,-8);
 desc select * from th where a=-2;
 id	estRows	task	access object	operator info
-TableReader_9	10.00	root		data:Selection_8
-└─Selection_8	10.00	cop[tikv]		eq(test.th.a, -2)
-  └─TableFullScan_7	10000.00	cop[tikv]	table:th, partition:p2	keep order:false, stats:pseudo
+TableReader_7	10.00	root	partition:p2	data:Selection_6
+└─Selection_6	10.00	cop[tikv]		eq(test.th.a, -2)
+  └─TableFullScan_5	10000.00	cop[tikv]	table:th	keep order:false, stats:pseudo
 desc select * from th;
 id	estRows	task	access object	operator info
-PartitionUnion_9	30000.00	root		
-├─TableReader_11	10000.00	root		data:TableFullScan_10
-│ └─TableFullScan_10	10000.00	cop[tikv]	table:th, partition:p0	keep order:false, stats:pseudo
-├─TableReader_13	10000.00	root		data:TableFullScan_12
-│ └─TableFullScan_12	10000.00	cop[tikv]	table:th, partition:p1	keep order:false, stats:pseudo
-└─TableReader_15	10000.00	root		data:TableFullScan_14
-  └─TableFullScan_14	10000.00	cop[tikv]	table:th, partition:p2	keep order:false, stats:pseudo
+TableReader_5	10000.00	root	partition:all	data:TableFullScan_4
+└─TableFullScan_4	10000.00	cop[tikv]	table:th	keep order:false, stats:pseudo
 desc select * from th partition (p2,p1);
 id	estRows	task	access object	operator info
-PartitionUnion_8	20000.00	root		
-├─TableReader_10	10000.00	root		data:TableFullScan_9
-│ └─TableFullScan_9	10000.00	cop[tikv]	table:th, partition:p1	keep order:false, stats:pseudo
-└─TableReader_12	10000.00	root		data:TableFullScan_11
-  └─TableFullScan_11	10000.00	cop[tikv]	table:th, partition:p2	keep order:false, stats:pseudo
+TableReader_5	10000.00	root	partition:p1,p2	data:TableFullScan_4
+└─TableFullScan_4	10000.00	cop[tikv]	table:th	keep order:false, stats:pseudo
 drop table if exists t;
 create table t(a int, b int);
 explain format = 'brief' select a != any (select a from t t2) from t t1;
diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go
index a39e6c2883670..bb5849b3f4607 100644
--- a/planner/core/integration_test.go
+++ b/planner/core/integration_test.go
@@ -3163,20 +3163,6 @@ func (s *testIntegrationSuite) TestIssue22892(c *C) {
 	tk.MustQuery("select * from t2 where a not between 1 and 2;").Check(testkit.Rows("0"))
 }
 
-func (s *testIntegrationSuite) TestIssue26719(c *C) {
-	tk := testkit.NewTestKit(c, s.store)
-	tk.MustExec("use test")
-	tk.MustExec(`create table tx (a int) partition by range (a) (partition p0 values less than (10), partition p1 values less than (20))`)
-	tk.MustExec(`insert into tx values (1)`)
-	tk.MustExec("set @@tidb_partition_prune_mode='dynamic'")
-
-	tk.MustExec(`begin`)
-	tk.MustExec(`delete from tx where a in (1)`)
-	tk.MustQuery(`select * from tx PARTITION(p0)`).Check(testkit.Rows())
-	tk.MustQuery(`select * from tx`).Check(testkit.Rows())
-	tk.MustExec(`rollback`)
-}
-
 func (s *testIntegrationSerialSuite) TestPushDownProjectionForTiFlash(c *C) {
 	tk := testkit.NewTestKit(c, s.store)
 	tk.MustExec("use test")
diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go
index 48a3dc8453454..1baa10d590da8 100644
--- a/sessionctx/variable/session.go
+++ b/sessionctx/variable/session.go
@@ -917,11 +917,6 @@ func (s *SessionVars) CheckAndGetTxnScope() string {
 
 // UseDynamicPartitionPrune indicates whether use new dynamic partition prune.
 func (s *SessionVars) UseDynamicPartitionPrune() bool {
-	if s.InTxn() {
-		// UnionScan cannot get partition table IDs in dynamic-mode, this is a quick-fix for issues/26719,
-		// please see it for more details.
-		return false
-	}
 	return PartitionPruneMode(s.PartitionPruneMode.Load()) == Dynamic
 }
 

From bff0034a8f0f2939bca066f99082315245c8b16e Mon Sep 17 00:00:00 2001
From: Song Gao <disxiaofei@163.com>
Date: Tue, 3 Aug 2021 20:27:06 +0800
Subject: [PATCH 3/8] executor: fix unstable TestStaleSelect (#26840)

---
 executor/stale_txn_test.go | 14 +-------------
 1 file changed, 1 insertion(+), 13 deletions(-)

diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go
index b331b1f30e2a2..8c4dbcd40350f 100644
--- a/executor/stale_txn_test.go
+++ b/executor/stale_txn_test.go
@@ -865,8 +865,7 @@ func (s *testStaleTxnSuite) TestSetTransactionInfoSchema(c *C) {
 	c.Assert(tk.Se.GetInfoSchema().SchemaMetaVersion(), Equals, schemaVer3)
 }
 
-func (s *testStaleTxnSuite) TestStaleSelect(c *C) {
-	c.Skip("unstable, skip it and fix it before 20210702")
+func (s *testStaleTxnSerialSuite) TestStaleSelect(c *C) {
 	tk := testkit.NewTestKit(c, s.store)
 	tk.MustExec("use test")
 	tk.MustExec("drop table if exists t")
@@ -916,20 +915,9 @@ func (s *testStaleTxnSuite) TestStaleSelect(c *C) {
 	tk.MustExec("insert into t values (4, 5)")
 	time.Sleep(10 * time.Millisecond)
 	tk.MustQuery("execute s").Check(staleRows)
-
-	// test dynamic timestamp stale select
-	time3 := time.Now()
 	tk.MustExec("alter table t add column d int")
 	tk.MustExec("insert into t values (4, 4, 4)")
 	time.Sleep(tolerance)
-	time4 := time.Now()
-	staleRows = testkit.Rows("1 <nil>", "2 <nil>", "3 <nil>", "4 5")
-	tk.MustQuery(fmt.Sprintf("select * from t as of timestamp CURRENT_TIMESTAMP(3) - INTERVAL %d MICROSECOND", time4.Sub(time3).Microseconds())).Check(staleRows)
-
-	// test prepared dynamic timestamp stale select
-	time5 := time.Now()
-	tk.MustExec(fmt.Sprintf(`prepare v from "select * from t as of timestamp CURRENT_TIMESTAMP(3) - INTERVAL %d MICROSECOND"`, time5.Sub(time3).Microseconds()))
-	tk.MustQuery("execute v").Check(staleRows)
 
 	// test point get
 	time6 := time.Now()

From 6f913973ba4ddfc194ab333d645cafa30ee5881c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= <git@myname.nl>
Date: Tue, 3 Aug 2021 14:45:06 +0200
Subject: [PATCH 4/8] tikv: Fix 'tatsk' typo in mockstore (#26802)

---
 store/mockstore/unistore/tikv/server.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/store/mockstore/unistore/tikv/server.go b/store/mockstore/unistore/tikv/server.go
index a0bc96b815dfa..58aabc0de638d 100644
--- a/store/mockstore/unistore/tikv/server.go
+++ b/store/mockstore/unistore/tikv/server.go
@@ -788,7 +788,7 @@ func (svr *Server) EstablishMPPConnectionWithStoreID(req *mpp.EstablishMPPConnec
 		}
 	}
 	if mppHandler == nil {
-		return errors.New("tatsk not found")
+		return errors.New("task not found")
 	}
 	ctx1, cancel := context.WithCancel(context.Background())
 	defer cancel()

From 9e98025f1a258c6a734c5d158d284641734dd4a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= <cclcwangchao@hotmail.com>
Date: Tue, 3 Aug 2021 21:41:07 +0800
Subject: [PATCH 5/8] ddl: truncate local temporary table (#26466)

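Truncating a session-local temporary table is handled entirely in the
executor: the table metadata is re-created under a freshly allocated table
ID (which also resets auto_increment), the entry in the session's
LocalTemporaryTables is replaced, and the old table's records are removed
from the session memory buffer. A rough, self-contained sketch of that flow
with hypothetical types (the real code uses infoschema, autoid and
TemporaryTableData):

    package main

    import "fmt"

    type tempTable struct {
        ID   int64
        Name string
    }

    type session struct {
        nextID int64
        tables map[string]*tempTable // per-session temporary tables
        rows   map[int64][]string    // buffered rows, keyed by table ID
    }

    // truncate mirrors executeTruncateLocalTemporaryTable: new metadata under
    // a fresh ID, then drop the rows that belonged to the old ID.
    func (s *session) truncate(name string) error {
        old, ok := s.tables[name]
        if !ok {
            return fmt.Errorf("table %s does not exist", name)
        }
        s.nextID++
        s.tables[name] = &tempTable{ID: s.nextID, Name: name}
        delete(s.rows, old.ID)
        return nil
    }

    func main() {
        s := &session{nextID: 100, tables: map[string]*tempTable{}, rows: map[int64][]string{}}
        s.tables["t1"] = &tempTable{ID: 100, Name: "t1"}
        s.rows[100] = []string{"1", "2", "3"}
        _ = s.truncate("t1")
        fmt.Println(s.tables["t1"].ID, len(s.rows[100])) // 101 0
    }
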
---
 ddl/db_integration_test.go    |  97 ++++++++++++++++++++++++++++++
 executor/ddl.go               | 108 ++++++++++++++++++++++++++--------
 infoschema/infoschema.go      |   4 +-
 infoschema/infoschema_test.go |  14 ++---
 4 files changed, 188 insertions(+), 35 deletions(-)

diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go
index ea8e614b08cad..591df57cc7359 100644
--- a/ddl/db_integration_test.go
+++ b/ddl/db_integration_test.go
@@ -3134,3 +3134,100 @@ func (s *testIntegrationSuite3) TestDropTemporaryTable(c *C) {
 	c.Assert(err.Error(), Equals, "[schema:1051]Unknown table 'test.a_local_temp_table_9_not_exist'")
 	tk.MustQuery("select * from a_local_temp_table_8").Check(testkit.Rows())
 }
+
+func (s *testIntegrationSuite3) TestTruncateLocalTemporaryTable(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test")
+	tk.MustExec("set @@tidb_enable_noop_functions = 1")
+
+	tk.MustExec("drop table if exists t1, tn")
+	tk.MustExec("create table t1 (id int)")
+	tk.MustExec("create table tn (id int)")
+	tk.MustExec("insert into t1 values(10), (11), (12)")
+	tk.MustExec("create temporary table t1 (id int primary key auto_increment)")
+	tk.MustExec("create temporary table t2 (id int primary key)")
+	tk.MustExec("create database test2")
+	tk.MustExec("create temporary table test2.t2 (id int)")
+
+	// truncate table out of txn
+	tk.MustExec("insert into t1 values(1), (2), (3)")
+	tk.MustExec("insert into t2 values(4), (5), (6)")
+	tk.MustExec("insert into test2.t2 values(7), (8), (9)")
+	tk.MustExec("truncate table t1")
+	tk.MustQuery("select * from t1").Check(testkit.Rows())
+	tk.MustExec("insert into t1 values()")
+	// auto_increment will be reset for truncate
+	tk.MustQuery("select * from t1").Check(testkit.Rows("1"))
+	tk.MustQuery("select * from t2").Check(testkit.Rows("4", "5", "6"))
+	tk.MustExec("truncate table t2")
+	tk.MustQuery("select * from t2").Check(testkit.Rows())
+	tk.MustQuery("select * from test2.t2").Check(testkit.Rows("7", "8", "9"))
+	tk.MustExec("drop table t1")
+	tk.MustQuery("select * from t1").Check(testkit.Rows("10", "11", "12"))
+	tk.MustExec("create temporary table t1 (id int primary key auto_increment)")
+
+	// truncate table with format dbName.tableName
+	tk.MustExec("insert into t2 values(4), (5), (6)")
+	tk.MustExec("insert into test2.t2 values(7), (8), (9)")
+	tk.MustExec("truncate table test2.t2")
+	tk.MustQuery("select * from test2.t2").Check(testkit.Rows())
+	tk.MustQuery("select * from t2").Check(testkit.Rows("4", "5", "6"))
+	tk.MustExec("truncate table test.t2")
+	tk.MustQuery("select * from t2").Check(testkit.Rows())
+
+	// truncate table in txn
+	tk.MustExec("insert into t1 values(1), (2), (3)")
+	tk.MustExec("insert into t2 values(4), (5), (6)")
+	tk.MustExec("begin")
+	tk.MustExec("insert into t1 values(11), (12)")
+	tk.MustExec("insert into t2 values(24), (25)")
+	tk.MustExec("delete from t1 where id=2")
+	tk.MustExec("delete from t2 where id=4")
+	tk.MustExec("truncate table t1")
+	tk.MustQuery("select * from t1").Check(testkit.Rows())
+	tk.MustExec("insert into t1 values()")
+	// auto_increment will be reset for truncate
+	tk.MustQuery("select * from t1").Check(testkit.Rows("1"))
+	tk.MustQuery("select * from t2").Check(testkit.Rows("5", "6", "24", "25"))
+
+	// since truncate already committed the transaction, the query after rollback will get the same result
+	tk.MustExec("rollback")
+	tk.MustQuery("select * from t1").Check(testkit.Rows("1"))
+	tk.MustQuery("select * from t2").Check(testkit.Rows("5", "6", "24", "25"))
+
+	// truncating a temporary table will not affect the normal table with the same name
+	tk.MustExec("drop table t1")
+	tk.MustQuery("select * from t1").Check(testkit.Rows("10", "11", "12"))
+	tk.MustExec("create temporary table t1 (id int primary key auto_increment)")
+
+	// truncating a temporary table will clear its session data
+	localTemporaryTables := tk.Se.GetSessionVars().LocalTemporaryTables.(*infoschema.LocalTemporaryTables)
+	tb1, exist := localTemporaryTables.TableByName(model.NewCIStr("test"), model.NewCIStr("t1"))
+	tbl1Info := tb1.Meta()
+	tablePrefix := tablecodec.EncodeTablePrefix(tbl1Info.ID)
+	endTablePrefix := tablecodec.EncodeTablePrefix(tbl1Info.ID + 1)
+	c.Assert(exist, IsTrue)
+	tk.MustExec("insert into t1 values(1), (2), (3)")
+	tk.MustExec("begin")
+	tk.MustExec("insert into t1 values(5), (6), (7)")
+	tk.MustExec("truncate table t1")
+	iter, err := tk.Se.GetSessionVars().TemporaryTableData.Iter(tablePrefix, endTablePrefix)
+	c.Assert(err, IsNil)
+	for iter.Valid() {
+		key := iter.Key()
+		if !bytes.HasPrefix(key, tablePrefix) {
+			break
+		}
+		value := iter.Value()
+		c.Assert(len(value), Equals, 0)
+		_ = iter.Next()
+	}
+	c.Assert(iter.Valid(), IsFalse)
+
+	// truncate after drop database should be successful
+	tk.MustExec("create temporary table test2.t3 (id int)")
+	tk.MustExec("insert into test2.t3 values(1)")
+	tk.MustExec("drop database test2")
+	tk.MustExec("truncate table test2.t3")
+	tk.MustQuery("select * from test2.t3").Check(testkit.Rows())
+}
diff --git a/executor/ddl.go b/executor/ddl.go
index 99de4f6db9f5c..8679ae150a296 100644
--- a/executor/ddl.go
+++ b/executor/ddl.go
@@ -33,6 +33,7 @@ import (
 	"github.com/pingcap/tidb/meta/autoid"
 	"github.com/pingcap/tidb/planner/core"
 	"github.com/pingcap/tidb/sessionctx/variable"
+	"github.com/pingcap/tidb/table"
 	"github.com/pingcap/tidb/table/tables"
 	"github.com/pingcap/tidb/tablecodec"
 	"github.com/pingcap/tidb/util/admin"
@@ -104,6 +105,27 @@ func deleteTemporaryTableRecords(memData kv.MemBuffer, tblID int64) error {
 	return nil
 }
 
+func (e *DDLExec) getLocalTemporaryTables() *infoschema.LocalTemporaryTables {
+	tempTables := e.ctx.GetSessionVars().LocalTemporaryTables
+	if tempTables != nil {
+		return tempTables.(*infoschema.LocalTemporaryTables)
+	}
+	return nil
+}
+
+func (e *DDLExec) getLocalTemporaryTable(schema model.CIStr, table model.CIStr) (table.Table, bool) {
+	tbl, err := e.ctx.GetInfoSchema().(infoschema.InfoSchema).TableByName(schema, table)
+	if infoschema.ErrTableNotExists.Equal(err) {
+		return nil, false
+	}
+
+	if tbl.Meta().TempTableType != model.TempTableLocal {
+		return nil, false
+	}
+
+	return tbl, true
+}
+
 // Next implements the Executor Next interface.
 func (e *DDLExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
 	if e.done {
@@ -218,10 +240,40 @@ func (e *DDLExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
 
 func (e *DDLExec) executeTruncateTable(s *ast.TruncateTableStmt) error {
 	ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name}
+	if _, exist := e.getLocalTemporaryTable(s.Table.Schema, s.Table.Name); exist {
+		return e.executeTruncateLocalTemporaryTable(s)
+	}
 	err := domain.GetDomain(e.ctx).DDL().TruncateTable(e.ctx, ident)
 	return err
 }
 
+func (e *DDLExec) executeTruncateLocalTemporaryTable(s *ast.TruncateTableStmt) error {
+	tbl, exists := e.getLocalTemporaryTable(s.Table.Schema, s.Table.Name)
+	if !exists {
+		return infoschema.ErrTableNotExists.GenWithStackByArgs(s.Table.Schema, s.Table.Name)
+	}
+
+	tblInfo := tbl.Meta()
+
+	newTbl, err := e.newTemporaryTableFromTableInfo(tblInfo.Clone())
+	if err != nil {
+		return err
+	}
+
+	localTempTables := e.getLocalTemporaryTables()
+	localTempTables.RemoveTable(s.Table.Schema, s.Table.Name)
+	if err := localTempTables.AddTable(s.Table.Schema, newTbl); err != nil {
+		return err
+	}
+
+	err = deleteTemporaryTableRecords(e.ctx.GetSessionVars().TemporaryTableData, tblInfo.ID)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func (e *DDLExec) executeRenameTable(s *ast.RenameTableStmt) error {
 	isAlterTable := false
 	var err error
@@ -286,31 +338,7 @@ func (e *DDLExec) createSessionTemporaryTable(s *ast.CreateTableStmt) error {
 		return err
 	}
 
-	dom := domain.GetDomain(e.ctx)
-	// Local temporary table uses a real table ID.
-	// We could mock a table ID, but the mocked ID might be identical to an existing
-	// real table, and then we'll get into trouble.
-	err = kv.RunInNewTxn(context.Background(), dom.Store(), true, func(ctx context.Context, txn kv.Transaction) error {
-		m := meta.NewMeta(txn)
-		tblID, err := m.GenGlobalID()
-		if err != nil {
-			return errors.Trace(err)
-		}
-		tbInfo.ID = tblID
-		tbInfo.State = model.StatePublic
-		return nil
-	})
-	if err != nil {
-		return err
-	}
-
-	// AutoID is allocated in mocked..
-	alloc := autoid.NewAllocatorFromTempTblInfo(tbInfo)
-	allocs := make([]autoid.Allocator, 0, 1)
-	if alloc != nil {
-		allocs = append(allocs, alloc)
-	}
-	tbl, err := tables.TableFromMeta(allocs, tbInfo)
+	tbl, err := e.newTemporaryTableFromTableInfo(tbInfo)
 	if err != nil {
 		return err
 	}
@@ -333,7 +361,7 @@ func (e *DDLExec) createSessionTemporaryTable(s *ast.CreateTableStmt) error {
 		sessVars.TemporaryTableData = bufferTxn.GetMemBuffer()
 	}
 
-	err = localTempTables.AddTable(dbInfo, tbl)
+	err = localTempTables.AddTable(dbInfo.Name, tbl)
 
 	if err != nil && s.IfNotExists && infoschema.ErrTableExists.Equal(err) {
 		e.ctx.GetSessionVars().StmtCtx.AppendNote(err)
@@ -343,6 +371,34 @@ func (e *DDLExec) createSessionTemporaryTable(s *ast.CreateTableStmt) error {
 	return err
 }
 
+func (e *DDLExec) newTemporaryTableFromTableInfo(tbInfo *model.TableInfo) (table.Table, error) {
+	dom := domain.GetDomain(e.ctx)
+	// Local temporary table uses a real table ID.
+	// We could mock a table ID, but the mocked ID might be identical to an existing
+	// real table, and then we'll get into trouble.
+	err := kv.RunInNewTxn(context.Background(), dom.Store(), true, func(ctx context.Context, txn kv.Transaction) error {
+		m := meta.NewMeta(txn)
+		tblID, err := m.GenGlobalID()
+		if err != nil {
+			return errors.Trace(err)
+		}
+		tbInfo.ID = tblID
+		tbInfo.State = model.StatePublic
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// AutoID is allocated by a mocked (in-memory) allocator.
+	alloc := autoid.NewAllocatorFromTempTblInfo(tbInfo)
+	allocs := make([]autoid.Allocator, 0, 1)
+	if alloc != nil {
+		allocs = append(allocs, alloc)
+	}
+	return tables.TableFromMeta(allocs, tbInfo)
+}
+
 func (e *DDLExec) executeCreateView(s *ast.CreateViewStmt) error {
 	ret := &core.PreprocessorReturn{}
 	err := core.Preprocess(e.ctx, s.Select, core.WithPreprocessorReturn(ret))
diff --git a/infoschema/infoschema.go b/infoschema/infoschema.go
index 1ec089bc8413f..41b5b2c2f6e7c 100644
--- a/infoschema/infoschema.go
+++ b/infoschema/infoschema.go
@@ -445,8 +445,8 @@ func (is *LocalTemporaryTables) TableByID(id int64) (tbl table.Table, ok bool) {
 }
 
 // AddTable add a table
-func (is *LocalTemporaryTables) AddTable(schema *model.DBInfo, tbl table.Table) error {
-	schemaTables := is.ensureSchema(schema.Name)
+func (is *LocalTemporaryTables) AddTable(schema model.CIStr, tbl table.Table) error {
+	schemaTables := is.ensureSchema(schema)
 
 	tblMeta := tbl.Meta()
 	if _, ok := schemaTables.tables[tblMeta.Name.L]; ok {
diff --git a/infoschema/infoschema_test.go b/infoschema/infoschema_test.go
index 302a9d6f9f472..06eee146d438e 100644
--- a/infoschema/infoschema_test.go
+++ b/infoschema/infoschema_test.go
@@ -511,7 +511,7 @@ func (*testSuite) TestLocalTemporaryTables(c *C) {
 	}
 
 	for _, p := range prepareTables {
-		err = sc.AddTable(p.db, p.tb)
+		err = sc.AddTable(p.db.Name, p.tb)
 		c.Assert(err, IsNil)
 	}
 
@@ -543,16 +543,16 @@ func (*testSuite) TestLocalTemporaryTables(c *C) {
 	}
 
 	// test add dup table
-	err = sc.AddTable(db1, tb11)
+	err = sc.AddTable(db1.Name, tb11)
 	c.Assert(infoschema.ErrTableExists.Equal(err), IsTrue)
-	err = sc.AddTable(db1b, tb15)
+	err = sc.AddTable(db1b.Name, tb15)
 	c.Assert(infoschema.ErrTableExists.Equal(err), IsTrue)
-	err = sc.AddTable(db1b, tb11)
+	err = sc.AddTable(db1b.Name, tb11)
 	c.Assert(infoschema.ErrTableExists.Equal(err), IsTrue)
 	db1c := createNewSchemaInfo("db1")
-	err = sc.AddTable(db1c, createNewTable(db1c.ID, "tb1", model.TempTableLocal))
+	err = sc.AddTable(db1c.Name, createNewTable(db1c.ID, "tb1", model.TempTableLocal))
 	c.Assert(infoschema.ErrTableExists.Equal(err), IsTrue)
-	err = sc.AddTable(db1b, tb11)
+	err = sc.AddTable(db1b.Name, tb11)
 	c.Assert(infoschema.ErrTableExists.Equal(err), IsTrue)
 
 	// failed add has no effect
@@ -598,7 +598,7 @@ func (*testSuite) TestLocalTemporaryTables(c *C) {
 		LocalTemporaryTables: sc,
 	}
 
-	err = sc.AddTable(dbTest, tmpTbTestA)
+	err = sc.AddTable(dbTest.Name, tmpTbTestA)
 	c.Assert(err, IsNil)
 
 	// test TableByName

From 5e05922de6a253859cfbfe19356de8a2e2db39da Mon Sep 17 00:00:00 2001
From: ateb14 <85857627+ateb14@users.noreply.github.com>
Date: Wed, 4 Aug 2021 00:01:07 +0800
Subject: [PATCH 6/8] util/localpool: migrate test-infra to testify for
 util/localpool pkg #26183 (#26768)

---
 util/localpool/localpool_test.go | 17 ++++-------------
 util/localpool/main_test.go      | 26 ++++++++++++++++++++++++++
 2 files changed, 30 insertions(+), 13 deletions(-)
 create mode 100644 util/localpool/main_test.go

diff --git a/util/localpool/localpool_test.go b/util/localpool/localpool_test.go
index 77c4f92e80171..938e30b5d6284 100644
--- a/util/localpool/localpool_test.go
+++ b/util/localpool/localpool_test.go
@@ -21,7 +21,7 @@ import (
 	"sync"
 	"testing"
 
-	. "github.com/pingcap/check"
+	"github.com/stretchr/testify/require"
 )
 
 type Obj struct {
@@ -29,16 +29,7 @@ type Obj struct {
 	val int64 // nolint:structcheck // Dummy field to make it non-empty.
 }
 
-func TestT(t *testing.T) {
-	TestingT(t)
-}
-
-var _ = Suite(&testPoolSuite{})
-
-type testPoolSuite struct {
-}
-
-func (s *testPoolSuite) TestPool(c *C) {
+func TestPool(t *testing.T) {
 	numWorkers := runtime.GOMAXPROCS(0)
 	wg := new(sync.WaitGroup)
 	wg.Add(numWorkers)
@@ -62,8 +53,8 @@ func (s *testPoolSuite) TestPool(c *C) {
 		putHit += slot.putHit
 		putMiss += slot.putMiss
 	}
-	c.Assert(getHit, Greater, getMiss)
-	c.Assert(putHit, Greater, putMiss)
+	require.Greater(t, getHit, getMiss)
+	require.Greater(t, putHit, putMiss)
 }
 
 func GetAndPut(pool *LocalPool) {
diff --git a/util/localpool/main_test.go b/util/localpool/main_test.go
new file mode 100644
index 0000000000000..9b1110e85c45d
--- /dev/null
+++ b/util/localpool/main_test.go
@@ -0,0 +1,26 @@
+// Copyright 2021 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package localpool
+
+import (
+	"testing"
+
+	"github.com/pingcap/tidb/util/testbridge"
+	"go.uber.org/goleak"
+)
+
+func TestMain(m *testing.M) {
+	testbridge.WorkaroundGoCheckFlags()
+	goleak.VerifyTestMain(m)
+}

From 58f10fecee10c98ed7236c4770de4b8d0d6b04f5 Mon Sep 17 00:00:00 2001
From: unconsolable <chenzhipeng2012@gmail.com>
Date: Wed, 4 Aug 2021 02:03:06 +0800
Subject: [PATCH 7/8] plugin: migrate test-infra to testify (#26769)

---
 plugin/const_test.go  |   8 ++--
 plugin/helper_test.go |  28 +++++-------
 plugin/main_test.go   |  26 +++++++++++
 plugin/plugin_test.go | 102 ++++++++++++++----------------------------
 plugin/spi_test.go    |  11 +++--
 5 files changed, 81 insertions(+), 94 deletions(-)
 create mode 100644 plugin/main_test.go

diff --git a/plugin/const_test.go b/plugin/const_test.go
index dd366b41d2c4e..f75ca4d4138b7 100644
--- a/plugin/const_test.go
+++ b/plugin/const_test.go
@@ -16,9 +16,12 @@ package plugin
 import (
 	"fmt"
 	"testing"
+
+	"github.com/stretchr/testify/require"
 )
 
 func TestConstToString(t *testing.T) {
+	t.Parallel()
 	kinds := map[fmt.Stringer]string{
 		Audit:                     "Audit",
 		Authentication:            "Authentication",
@@ -32,11 +35,10 @@ func TestConstToString(t *testing.T) {
 		Disconnect:                "Disconnect",
 		ChangeUser:                "ChangeUser",
 		PreAuth:                   "PreAuth",
+		Reject:                    "Reject",
 		ConnectionEvent(byte(15)): "",
 	}
 	for key, value := range kinds {
-		if key.String() != value {
-			t.Errorf("kind %s != %s", key.String(), kinds)
-		}
+		require.Equal(t, value, key.String())
 	}
 }
diff --git a/plugin/helper_test.go b/plugin/helper_test.go
index 1bb3fc71420ec..d0701ea789099 100644
--- a/plugin/helper_test.go
+++ b/plugin/helper_test.go
@@ -13,42 +13,38 @@
 
 package plugin
 
-import "testing"
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
 
 func TestPluginDeclare(t *testing.T) {
+	t.Parallel()
 	auditRaw := &AuditManifest{Manifest: Manifest{}}
 	auditExport := ExportManifest(auditRaw)
 	audit2 := DeclareAuditManifest(auditExport)
-	if audit2 != auditRaw {
-		t.Errorf("declare audit fail")
-	}
+	require.Equal(t, auditRaw, audit2)
 
 	authRaw := &AuthenticationManifest{Manifest: Manifest{}}
 	authExport := ExportManifest(authRaw)
 	auth2 := DeclareAuthenticationManifest(authExport)
-	if auth2 != authRaw {
-		t.Errorf("declare auth fail")
-	}
+	require.Equal(t, authRaw, auth2)
 
 	schemaRaw := &SchemaManifest{Manifest: Manifest{}}
 	schemaExport := ExportManifest(schemaRaw)
 	schema2 := DeclareSchemaManifest(schemaExport)
-	if schema2 != schemaRaw {
-		t.Errorf("declare schema fail")
-	}
+	require.Equal(t, schemaRaw, schema2)
 
 	daemonRaw := &DaemonManifest{Manifest: Manifest{}}
 	daemonExport := ExportManifest(daemonRaw)
 	daemon2 := DeclareDaemonManifest(daemonExport)
-	if daemon2 != daemonRaw {
-		t.Errorf("declare daemon fail")
-	}
+	require.Equal(t, daemonRaw, daemon2)
 }
 
 func TestDecode(t *testing.T) {
+	t.Parallel()
 	failID := ID("fail")
 	_, _, err := failID.Decode()
-	if err == nil {
-		t.Errorf("'fail' should not decode success")
-	}
+	require.Error(t, err)
 }
diff --git a/plugin/main_test.go b/plugin/main_test.go
new file mode 100644
index 0000000000000..108caec196390
--- /dev/null
+++ b/plugin/main_test.go
@@ -0,0 +1,26 @@
+// Copyright 2021 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package plugin
+
+import (
+	"testing"
+
+	"github.com/pingcap/tidb/util/testbridge"
+	"go.uber.org/goleak"
+)
+
+func TestMain(m *testing.M) {
+	testbridge.WorkaroundGoCheckFlags()
+	goleak.VerifyTestMain(m)
+}
diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go
index 6b1ec9e2c3441..f07e9d1fcac8a 100644
--- a/plugin/plugin_test.go
+++ b/plugin/plugin_test.go
@@ -19,14 +19,10 @@ import (
 	"strconv"
 	"testing"
 
-	"github.com/pingcap/check"
 	"github.com/pingcap/tidb/sessionctx/variable"
+	"github.com/stretchr/testify/require"
 )
 
-func TestT(t *testing.T) {
-	check.TestingT(t)
-}
-
 func TestLoadPluginSuccess(t *testing.T) {
 	ctx := context.Background()
 
@@ -41,7 +37,7 @@ func TestLoadPluginSuccess(t *testing.T) {
 	}
 
 	// setup load test hook.
-	testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) {
+	SetTestHook(func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) {
 		return func() *Manifest {
 			m := &AuditManifest{
 				Manifest: Manifest{
@@ -63,55 +59,40 @@ func TestLoadPluginSuccess(t *testing.T) {
 			}
 			return ExportManifest(m)
 		}, nil
-	}}
+	})
 	defer func() {
 		testHook = nil
 	}()
 
 	// trigger load.
 	err := Load(ctx, cfg)
-	if err != nil {
-		t.Errorf("load plugin [%s] fail", pluginSign)
-	}
+	require.NoError(t, err)
 
 	err = Init(ctx, cfg)
-	if err != nil {
-		t.Errorf("init plugin [%s] fail", pluginSign)
-	}
+	require.NoError(t, err)
 
 	// load all.
 	ps := GetAll()
-	if len(ps) != 1 {
-		t.Errorf("loaded plugins is empty")
-	}
+	require.Len(t, ps, 1)
+	require.True(t, IsEnable(Authentication))
 
 	// find plugin by type and name
 	p := Get(Authentication, "tplugin")
-	if p == nil {
-		t.Errorf("tplugin can not be load")
-	}
+	require.NotNil(t, p)
 	p = Get(Authentication, "tplugin2")
-	if p != nil {
-		t.Errorf("found miss plugin")
-	}
+	require.Nil(t, p)
 	p = getByName("tplugin")
-	if p == nil {
-		t.Errorf("can not find miss plugin")
-	}
+	require.NotNil(t, p)
 
 	// foreach plugin
 	err = ForeachPlugin(Authentication, func(plugin *Plugin) error {
 		return nil
 	})
-	if err != nil {
-		t.Errorf("foreach error %v", err)
-	}
+	require.NoError(t, err)
 	err = ForeachPlugin(Authentication, func(plugin *Plugin) error {
 		return io.EOF
 	})
-	if err != io.EOF {
-		t.Errorf("foreach should return EOF error")
-	}
+	require.Equal(t, io.EOF, err)
 
 	Shutdown(ctx)
 }
@@ -131,7 +112,7 @@ func TestLoadPluginSkipError(t *testing.T) {
 	}
 
 	// setup load test hook.
-	testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) {
+	SetTestHook(func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) {
 		return func() *Manifest {
 			m := &AuditManifest{
 				Manifest: Manifest{
@@ -153,58 +134,41 @@ func TestLoadPluginSkipError(t *testing.T) {
 			}
 			return ExportManifest(m)
 		}, nil
-	}}
+	})
 	defer func() {
 		testHook = nil
 	}()
 
 	// trigger load.
 	err := Load(ctx, cfg)
-	if err != nil {
-		t.Errorf("load plugin [%s] fail %v", pluginSign, err)
-	}
+	require.NoError(t, err)
 
 	err = Init(ctx, cfg)
-	if err != nil {
-		t.Errorf("init plugin [%s] fail", pluginSign)
-	}
+	require.NoError(t, err)
+	require.False(t, IsEnable(Audit))
 
 	// load all.
 	ps := GetAll()
-	if len(ps) != 1 {
-		t.Errorf("loaded plugins is empty")
-	}
+	require.Len(t, ps, 1)
 
 	// find plugin by type and name
 	p := Get(Audit, "tplugin")
-	if p == nil {
-		t.Errorf("tplugin can not be load")
-	}
+	require.NotNil(t, p)
 	p = Get(Audit, "tplugin2")
-	if p != nil {
-		t.Errorf("found miss plugin")
-	}
+	require.Nil(t, p)
 	p = getByName("tplugin")
-	if p == nil {
-		t.Errorf("can not find miss plugin")
-	}
+	require.NotNil(t, p)
 	p = getByName("not exists")
-	if p != nil {
-		t.Errorf("got not exists plugin")
-	}
+	require.Nil(t, p)
 
 	// foreach plugin
 	readyCount := 0
-	err = ForeachPlugin(Authentication, func(plugin *Plugin) error {
+	err = ForeachPlugin(Audit, func(plugin *Plugin) error {
 		readyCount++
 		return nil
 	})
-	if err != nil {
-		t.Errorf("foreach meet error %v", err)
-	}
-	if readyCount != 0 {
-		t.Errorf("validate fail can be load but no ready")
-	}
+	require.NoError(t, err)
+	require.Equal(t, 0, readyCount)
 
 	Shutdown(ctx)
 }
@@ -224,7 +188,7 @@ func TestLoadFail(t *testing.T) {
 	}
 
 	// setup load test hook.
-	testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) {
+	SetTestHook(func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) {
 		return func() *Manifest {
 			m := &AuditManifest{
 				Manifest: Manifest{
@@ -246,15 +210,13 @@ func TestLoadFail(t *testing.T) {
 			}
 			return ExportManifest(m)
 		}, nil
-	}}
+	})
 	defer func() {
 		testHook = nil
 	}()
 
 	err := Load(ctx, cfg)
-	if err == nil {
-		t.Errorf("load plugin should fail")
-	}
+	require.Error(t, err)
 }
 
 func TestPluginsClone(t *testing.T) {
@@ -273,7 +235,9 @@ func TestPluginsClone(t *testing.T) {
 	as := ps.plugins[Audit]
 	ps.plugins[Audit] = append(as, Plugin{})
 
-	if len(cps.plugins) != 1 || len(cps.plugins[Audit]) != 1 || len(cps.versions) != 1 || len(cps.dyingPlugins) != 1 {
-		t.Errorf("clone plugins failure")
-	}
+	require.Len(t, cps.plugins, 1)
+	require.Len(t, cps.plugins[Audit], 1)
+	require.Len(t, cps.versions, 1)
+	require.Equal(t, uint16(1), cps.versions["whitelist"])
+	require.Len(t, cps.dyingPlugins, 1)
 }
diff --git a/plugin/spi_test.go b/plugin/spi_test.go
index e619f9492a1bd..08ca1c0c28a85 100644
--- a/plugin/spi_test.go
+++ b/plugin/spi_test.go
@@ -19,9 +19,11 @@ import (
 
 	"github.com/pingcap/tidb/plugin"
 	"github.com/pingcap/tidb/sessionctx/variable"
+	"github.com/stretchr/testify/require"
 )
 
 func TestExportManifest(t *testing.T) {
+	t.Parallel()
 	callRecorder := struct {
 		OnInitCalled      bool
 		NotifyEventCalled bool
@@ -42,12 +44,9 @@ func TestExportManifest(t *testing.T) {
 	}
 	exported := plugin.ExportManifest(manifest)
 	err := exported.OnInit(context.Background(), exported)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 	audit := plugin.DeclareAuditManifest(exported)
 	audit.OnGeneralEvent(context.Background(), nil, plugin.Log, "QUERY")
-	if !callRecorder.NotifyEventCalled || !callRecorder.OnInitCalled {
-		t.Fatalf("export test failure")
-	}
+	require.True(t, callRecorder.NotifyEventCalled)
+	require.True(t, callRecorder.OnInitCalled)
 }

From 31403ad0a11e0d965c09003bfaeafe157a70eade Mon Sep 17 00:00:00 2001
From: guo-shaoge <shaoge1994@163.com>
Date: Wed, 4 Aug 2021 10:05:06 +0800
Subject: [PATCH 8/8] executor: fix unexpected behavior when casting invalid
 string to date (#26784)

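handleZeroDatetime no longer takes the NO_ZERO_DATE shortcut when the casted
time value itself is invalid, so a string such as '2020-02-31' is rejected
under STRICT_TRANS_TABLES and is only kept as-is when ALLOW_INVALID_DATES is
set (exercised by the new TestIssue26762). A simplified, illustrative sketch
of the decision, not the real function signature:

    package main

    import "fmt"

    // keepAsZero reports whether a zero/invalid datetime may be silently
    // stored as the zero value instead of being treated as an error
    // (simplified sketch).
    func keepAsZero(isZero, valIsInvalid, noZeroDateMode bool) bool {
        if isZero {
            // The NoZeroDate shortcut is skipped when the value is invalid.
            return !valIsInvalid && !noZeroDateMode
        }
        return false
    }

    func main() {
        fmt.Println(keepAsZero(true, false, false)) // true: a genuine zero date in lenient mode
        fmt.Println(keepAsZero(true, true, false))  // false: an invalid value no longer takes the shortcut
    }
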
---
 executor/insert_test.go | 17 +++++++++++++++++
 table/column.go         |  3 ++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/executor/insert_test.go b/executor/insert_test.go
index 17f1d3b05245a..bde17e6c1218b 100644
--- a/executor/insert_test.go
+++ b/executor/insert_test.go
@@ -1726,3 +1726,20 @@ func (s *testSuite13) TestGlobalTempTableParallel(c *C) {
 	}
 	wg.Wait()
 }
+
+func (s *testSuite13) TestIssue26762(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec(`use test`)
+	tk.MustExec("drop table if exists t1;")
+	tk.MustExec("create table t1(c1 date);")
+	_, err := tk.Exec("insert into t1 values('2020-02-31');")
+	c.Assert(err.Error(), Equals, `[table:1292]Incorrect date value: '2020-02-31' for column 'c1' at row 1`)
+
+	tk.MustExec("set @@sql_mode='ALLOW_INVALID_DATES';")
+	tk.MustExec("insert into t1 values('2020-02-31');")
+	tk.MustQuery("select * from t1").Check(testkit.Rows("2020-02-31"))
+
+	tk.MustExec("set @@sql_mode='STRICT_TRANS_TABLES';")
+	_, err = tk.Exec("insert into t1 values('2020-02-31');")
+	c.Assert(err.Error(), Equals, `[table:1292]Incorrect date value: '2020-02-31' for column 'c1' at row 1`)
+}
diff --git a/table/column.go b/table/column.go
index 433b3ee4a2723..843ab857a1fd2 100644
--- a/table/column.go
+++ b/table/column.go
@@ -246,7 +246,8 @@ func handleZeroDatetime(ctx sessionctx.Context, col *model.ColumnInfo, casted ty
 		return types.NewDatum(zeroV), true, nil
 	} else if tm.IsZero() || tm.InvalidZero() {
 		if tm.IsZero() {
-			if !mode.HasNoZeroDateMode() {
+			// Skip the NoZeroDate shortcut if the time value itself is invalid.
+			if !tmIsInvalid && !mode.HasNoZeroDateMode() {
 				return types.NewDatum(zeroV), true, nil
 			}
 		} else if tm.InvalidZero() {