-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathinsert.go
128 lines (108 loc) · 2.96 KB
/
insert.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package util
import (
"context"
"fmt"
"sync"
"github.com/bmeg/grip/gdbi"
"github.com/bmeg/grip/log"
multierror "github.com/hashicorp/go-multierror"
"golang.org/x/sync/semaphore"
)
// StreamBatch reads graph elements from stream and forwards them, in batches
// of up to batchSize, to the vertexAdd and edgeAdd bulk loaders.
//
// This function assumes the incoming stream is GraphElements from a single
// graph: any element whose Graph field differs from graph is recorded as an
// error and skipped. Vertices and edges that fail Validate are likewise
// recorded and skipped, and edges arriving without an ID are assigned one via
// UUID before validation. Elements carrying neither a vertex nor an edge are
// ignored.
//
// vertexAdd and edgeAdd each run in their own goroutine and receive a channel
// of elements plus the batch size. Both channels are closed once the stream is
// exhausted, and StreamBatch waits for both loaders to return.
//
// A batchSize below 1 is treated as 1.
//
// The returned error aggregates every problem encountered (rejected elements
// and loader failures); it is nil when there were none.
func StreamBatch(stream <-chan *gdbi.GraphElement, batchSize int, graph string, vertexAdd func(<-chan *gdbi.Vertex, int) error, edgeAdd func(<-chan *gdbi.Edge, int) error) error {
	// A zero batch size would give the semaphore no capacity, so the first
	// Acquire below would never return; a negative one would make the channel
	// allocation panic.
	if batchSize < 1 {
		batchSize = 1
	}

	var bulkErr *multierror.Error
	// Each loader error has exactly one writer (its goroutine) and is read
	// only after wg.Wait(), so the goroutines never touch bulkErr while the
	// loop below is appending to it.
	var vertexErr, edgeErr error
	vertCount := 0
	edgeCount := 0

	wg := &sync.WaitGroup{}
	vertexChan := make(chan *gdbi.Vertex, batchSize)
	edgeChan := make(chan *gdbi.Edge, batchSize)
	sem := semaphore.NewWeighted(int64(batchSize * 2))

	wg.Add(2)
	go func() {
		defer wg.Done()
		vertexErr = vertexAdd(vertexChan, batchSize)
		// If the loader returned before the channel was closed, keep
		// draining so the producer below can never block on a full channel.
		for range vertexChan {
		}
	}()
	go func() {
		defer wg.Done()
		edgeErr = edgeAdd(edgeChan, batchSize)
		// Same reasoning as for the vertex loader.
		for range edgeChan {
		}
	}()

	vertexBatch := make([]*gdbi.Vertex, 0, batchSize)
	edgeBatch := make([]*gdbi.Edge, 0, batchSize)

	// flushVertices sends the pending vertices to the loader, returns their
	// semaphore slots, and empties the batch while keeping its capacity.
	flushVertices := func() {
		for _, v := range vertexBatch {
			vertexChan <- v
		}
		sem.Release(int64(len(vertexBatch)))
		vertexBatch = vertexBatch[:0]
	}
	// flushEdges is the edge counterpart of flushVertices.
	flushEdges := func() {
		for _, e := range edgeBatch {
			edgeChan <- e
		}
		sem.Release(int64(len(edgeBatch)))
		edgeBatch = edgeBatch[:0]
	}

	for element := range stream {
		if element == nil {
			bulkErr = multierror.Append(bulkErr, fmt.Errorf("nil graph element in stream"))
			continue
		}
		if element.Graph != graph {
			bulkErr = multierror.Append(
				bulkErr,
				fmt.Errorf("unexpected graph reference: %s != %s", element.Graph, graph),
			)
			continue
		}

		switch {
		case element.Vertex != nil:
			vertex := element.Vertex
			if err := vertex.Validate(); err != nil {
				bulkErr = multierror.Append(
					bulkErr,
					fmt.Errorf("vertex validation failed: %w", err),
				)
				continue
			}
			if err := sem.Acquire(context.Background(), 1); err != nil {
				bulkErr = multierror.Append(
					bulkErr,
					fmt.Errorf("acquiring vertex batch slot: %w", err),
				)
				continue
			}
			vertexBatch = append(vertexBatch, vertex)
			vertCount++
			if len(vertexBatch) >= batchSize {
				flushVertices()
			}

		case element.Edge != nil:
			edge := element.Edge
			if edge.ID == "" {
				edge.ID = UUID()
			}
			if err := edge.Validate(); err != nil {
				bulkErr = multierror.Append(
					bulkErr,
					fmt.Errorf("edge validation failed: %w", err),
				)
				continue
			}
			if err := sem.Acquire(context.Background(), 1); err != nil {
				bulkErr = multierror.Append(
					bulkErr,
					fmt.Errorf("acquiring edge batch slot: %w", err),
				)
				continue
			}
			edgeBatch = append(edgeBatch, edge)
			edgeCount++
			if len(edgeBatch) >= batchSize {
				flushEdges()
			}
		}
	}

	// Send remaining vertices and edges in the partial batches.
	flushVertices()
	flushEdges()

	// Close channels after all data is sent so the loaders can finish.
	close(vertexChan)
	close(edgeChan)
	wg.Wait()

	// Safe to read the loader results now that both goroutines have exited.
	if vertexErr != nil {
		bulkErr = multierror.Append(bulkErr, vertexErr)
	}
	if edgeErr != nil {
		bulkErr = multierror.Append(bulkErr, edgeErr)
	}

	if vertCount != 0 {
		log.Debugf("%d vertices streamed to BulkAdd", vertCount)
	}
	if edgeCount != 0 {
		log.Debugf("%d edges streamed to BulkAdd", edgeCount)
	}
	return bulkErr.ErrorOrNil()
}