-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathCount_of_Range_Sum.cpp
107 lines (87 loc) · 3.2 KB
/
Count_of_Range_Sum.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/**
 * Counts the contiguous subarrays of an integer array whose element sum lies
 * in the closed interval [lower, upper].
 *
 * Approach: a subarray sum is the difference of two prefix sums. The distinct
 * prefix sums are sorted and used as the leaves of a segment tree keyed by
 * value; the array is then scanned right-to-left, inserting the prefix sum
 * that ends at each position and querying how many already-inserted prefix
 * sums fall in [lower + prefix, upper + prefix] for the prefix that precedes
 * that position.
 */
class Solution {
    // Prefix sums of 32-bit ints need 64 bits. `long` is only 32 bits wide on
    // LLP64 platforms, so the width is spelled out with `long long`.
    using Sum = long long;

    /** Tree node: the closed value interval it covers and how many inserted values fall in it. */
    struct SegmentNode {
        Sum lower = 0;
        Sum upper = 0;
        int count = 0;

        SegmentNode() = default;
        SegmentNode(Sum l, Sum r) : lower(l), upper(r) {}
    };

    /**
     * Segment tree over a sorted list of distinct values.
     * Nodes are stored heap-style starting at index 1: the children of `node`
     * are `2 * node` and `2 * node + 1`.
     */
    class SegmentTree {
        std::vector<SegmentNode> segmentNodes;
        int n = 0;  // number of leaves (distinct values)

        /** Builds the subtree for nums[left..right]; its value interval is [nums[left], nums[right]]. */
        void build(int node, int left, int right, const std::vector<Sum>& nums) {
            segmentNodes[node] = SegmentNode(nums[left], nums[right]);
            if (left == right) {
                return;
            }
            const int leftNode = node << 1;
            const int rightNode = leftNode | 1;
            const int mid = left + (right - left) / 2;
            build(leftNode, left, mid, nums);
            build(rightNode, mid + 1, right, nums);
        }

        /** Increments the count of every node on the path whose interval contains `val`. */
        void update(int node, int left, int right, Sum val) {
            if (val < segmentNodes[node].lower || val > segmentNodes[node].upper) {
                return;
            }
            ++segmentNodes[node].count;
            if (left == right) {
                return;
            }
            const int leftNode = node << 1;
            const int rightNode = leftNode | 1;
            const int mid = left + (right - left) / 2;
            // The leaves are sorted and distinct, so at most one child accepts
            // `val`; the other returns immediately from the range check above.
            update(leftNode, left, mid, val);
            update(rightNode, mid + 1, right, val);
        }

        /** Returns how many inserted values lie in [lower, upper] within this subtree. */
        int query(int node, int left, int right, const Sum lower, const Sum upper) const {
            if (upper < segmentNodes[node].lower || lower > segmentNodes[node].upper) {
                return 0;
            }
            if (segmentNodes[node].lower >= lower && segmentNodes[node].upper <= upper) {
                return segmentNodes[node].count;
            }
            // A leaf covers a single value, so it is always resolved by one of
            // the two checks above; only internal nodes reach this point.
            const int leftNode = node << 1;
            const int rightNode = leftNode | 1;
            const int mid = left + (right - left) / 2;
            return query(leftNode, left, mid, lower, upper)
                 + query(rightNode, mid + 1, right, lower, upper);
        }

    public:
        /**
         * (Re)builds the tree with all counts set to zero.
         * @param nums leaf values; must be sorted ascending and distinct.
         */
        void init(const std::vector<Sum>& nums) {
            n = static_cast<int>(nums.size());
            segmentNodes.clear();
            if (n == 0) {
                return;  // nothing to build; update/query become no-ops
            }
            // With the root at index 1, every node index is below 4 * n.
            segmentNodes.resize(4 * nums.size());
            build(1, 0, n - 1, nums);
        }

        /** Records one occurrence of `val`, which should be one of the leaf values. */
        void update(Sum val) {
            if (n == 0) {
                return;
            }
            update(1, 0, n - 1, val);
        }

        /** Returns the number of recorded values v with lower <= v <= upper. */
        int query(Sum lower, Sum upper) const {
            if (n == 0) {
                return 0;
            }
            return query(1, 0, n - 1, lower, upper);
        }
    };

public:
    /**
     * @param nums  input array
     * @param lower inclusive lower bound on the subarray sum
     * @param upper inclusive upper bound on the subarray sum
     * @return the number of index pairs (i, j), i <= j, for which
     *         nums[i] + ... + nums[j] lies in [lower, upper]; 0 for empty input.
     */
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        if (nums.empty()) {
            return 0;
        }
        const int n = static_cast<int>(nums.size());

        // Collect the distinct prefix sums in sorted order; after the loop
        // `sum` holds the total of the whole array.
        std::set<Sum> values;
        Sum sum = 0;
        for (const int x : nums) {
            sum += x;
            values.insert(sum);
        }
        const std::vector<Sum> sums(values.begin(), values.end());

        SegmentTree segmentTree;
        segmentTree.init(sums);

        int result = 0;
        for (int i = n - 1; i >= 0; --i) {
            // `sum` is the prefix sum ending at index i: make it available as
            // a right endpoint for every start position <= i.
            segmentTree.update(sum);
            // Now `sum` is the prefix sum before index i; a subarray starting
            // at i ends at a recorded prefix P with lower <= P - sum <= upper.
            sum -= nums[i];
            result += segmentTree.query(lower + sum, upper + sum);
        }
        return result;
    }
};