Skip to content

Commit

Permalink
Merge branch 'upstream1' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
kentnsw committed Aug 25, 2022
2 parents 55d7cca + f62f6b9 commit 39aec71
Show file tree
Hide file tree
Showing 34 changed files with 2,432 additions and 234 deletions.
137 changes: 89 additions & 48 deletions c7n/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
"""
import pickle # nosec nosemgrep

from datetime import datetime, timedelta
import os
import logging
import time
import sqlite3

log = logging.getLogger('custodian.cache')

Expand All @@ -30,12 +31,11 @@ def factory(config):
if not CACHE_NOTIFY:
log.debug("Using in-memory cache")
CACHE_NOTIFY = True
return InMemoryCache()
return InMemoryCache(config)
return SqlKvCache(config)

return FileCacheManager(config)


class NullCache:
class Cache:

def __init__(self, config):
    # Store the CLI/config options object; subclasses read cache settings
    # (e.g. cache path, cache_period) from it.
    self.config = config
Expand All @@ -55,80 +55,121 @@ def save(self, key, data):
def size(self):
    """Return the cache size in bytes; the base cache holds nothing."""
    return 0

def close(self):
    """Release any resources held by the cache; no-op for the base class."""
    pass

def __enter__(self):
    # Context-manager entry: load the cache contents, then yield the
    # cache itself so callers can get/save within the `with` block.
    self.load()
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    # Context-manager exit: always close, regardless of exceptions.
    self.close()


class InMemoryCache:
class NullCache(Cache):
    # Caching disabled: inherits the no-op load/save/size/close behavior
    # from the Cache base class unchanged.
    pass


class InMemoryCache(Cache):
# Running in a temporary environment, so keep as a cache.

__shared_state = {}

def __init__(self):
def __init__(self, config):
super().__init__(config)
self.data = self.__shared_state

def load(self):
return True

def get(self, key):
return self.data.get(pickle.dumps(key)) # nosemgrep
return self.data.get(encode(key))

def haskey(self, key):
return pickle.dumps(key) in self.data # nosemgrep

def save(self, key, data):
self.data[pickle.dumps(key)] = data # nosemgrep
self.data[encode(key)] = data

def size(self):
return sum(map(len, self.data.values()))


class FileCacheManager:
def encode(key):
    """Serialize *key* to a bytes blob suitable for use as a cache key."""
    proto = pickle.HIGHEST_PROTOCOL
    return pickle.dumps(key, protocol=proto)  # nosemgrep


def resolve_path(path):
    """Expand env vars and ``~`` in *path* and return its absolute form."""
    expanded = os.path.expandvars(path)
    expanded = os.path.expanduser(expanded)
    return os.path.abspath(expanded)


class SqlKvCache(Cache):

create_table = """
create table if not exists c7n_cache (
key blob primary key,
value blob,
create_date timestamp
)
"""

def __init__(self, config):
self.config = config
super().__init__(config)
self.cache_period = config.cache_period
self.cache_path = os.path.abspath(
os.path.expanduser(
os.path.expandvars(
config.cache)))
self.data = {}

def get(self, key):
k = pickle.dumps(key) # nosemgrep
return self.data.get(k)
self.cache_path = resolve_path(config.cache)
self.conn = None

def haskey(self, key):
return pickle.dumps(key) in self.data # nosemgrep

def load(self):
if self.data:
return True
if os.path.isfile(self.cache_path):
if (time.time() - os.stat(self.cache_path).st_mtime >
self.config.cache_period * 60):
return False
# migration from pickle cache file
if os.path.exists(self.cache_path):
with open(self.cache_path, 'rb') as fh:
try:
self.data = pickle.load(fh) # nosec nosemgrep
except EOFError:
return False
log.debug("Using cache file %s" % self.cache_path)
return True
header = fh.read(15)
if header != b'SQLite format 3':
log.debug('removing old cache file')
os.remove(self.cache_path)
elif not os.path.exists(os.path.dirname(self.cache_path)):
# parent directory creation
os.makedirs(os.path.dirname(self.cache_path))
self.conn = sqlite3.connect(self.cache_path)
self.conn.execute(self.create_table)
with self.conn as cursor:
log.debug('expiring stale cache entries')
cursor.execute(
'delete from c7n_cache where create_date < ?',
[datetime.utcnow() - timedelta(minutes=self.cache_period)])
return True

def save(self, key, data):
try:
with open(self.cache_path, 'wb') as fh: # nosec
self.data[pickle.dumps(key)] = data # nosemgrep
pickle.dump(self.data, fh, protocol=2) # nosemgrep
except Exception as e:
log.warning("Could not save cache %s err: %s" % (
self.cache_path, e))
if not os.path.exists(self.cache_path):
directory = os.path.dirname(self.cache_path)
log.info('Generating Cache directory: %s.' % directory)
try:
os.makedirs(directory)
except Exception as e:
log.warning("Could not create directory: %s err: %s" % (
directory, e))
def get(self, key):
    """Return the cached value for *key*, or None on a miss or stale entry."""
    with self.conn as cursor:
        r = cursor.execute(
            'select value, create_date from c7n_cache where key = ?',
            [sqlite3.Binary(encode(key))]
        )
        row = r.fetchone()
        if row is None:
            return None
        value, create_date = row
        # create_date comes back as text; parse it with sqlite3's built-in
        # TIMESTAMP converter (which expects a bytes argument).
        create_date = sqlite3.converters['TIMESTAMP'](create_date.encode('utf8'))
        # Entries older than cache_period (minutes) count as misses.
        if (datetime.utcnow() - create_date).total_seconds() / 60.0 > self.cache_period:
            return None
        return pickle.loads(value)  # nosec nosemgrep

def save(self, key, data, timestamp=None):
    """Store *data* under *key*, replacing any existing entry.

    *timestamp* defaults to now (UTC) and sets the entry's creation date,
    which get() uses for expiry checks.
    """
    with self.conn as cursor:
        timestamp = timestamp or datetime.utcnow()
        cursor.execute(
            'replace into c7n_cache (key, value, create_date) values (?, ?, ?)',
            (sqlite3.Binary(encode(key)), sqlite3.Binary(encode(data)), timestamp))

def size(self):
    """Return the size of the cache file in bytes, or 0 if it does not exist."""
    # Conditional expression instead of the error-prone `and ... or ...` idiom.
    return os.path.getsize(self.cache_path) if os.path.exists(self.cache_path) else 0

def close(self):
    """Close the sqlite connection, if one is open, and drop the reference."""
    conn, self.conn = self.conn, None
    if conn is not None:
        conn.close()
5 changes: 3 additions & 2 deletions c7n/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import os
import sys
from typing import List

import yaml
from yaml.constructor import ConstructorError
Expand Down Expand Up @@ -283,7 +284,7 @@ def validate(options):


@policy_command
def run(options, policies):
def run(options, policies: List[Policy]) -> None:
exit_code = 0

# AWS - Sanity check that we have an assumable role before executing policies
Expand All @@ -295,7 +296,7 @@ def run(options, policies):
log.exception("Unable to assume role %s", options.assume_role)
sys.exit(1)

errored_policies = []
errored_policies: List[str] = []
for policy in policies:
try:
policy()
Expand Down
14 changes: 7 additions & 7 deletions c7n/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import time
from typing import List

from dateutil import parser, tz as tzutil
import jmespath
Expand Down Expand Up @@ -68,12 +69,12 @@ class PolicyCollection:

log = logging.getLogger('c7n.policies')

def __init__(self, policies, options):
def __init__(self, policies: 'List[Policy]', options):
self.options = options
self.policies = policies

@classmethod
def from_data(cls, data, options, session_factory=None):
def from_data(cls, data: dict, options, session_factory=None):
# session factory param introduction needs an audit and review
# on tests.
sf = session_factory if session_factory else cls.session_factory()
Expand Down Expand Up @@ -1111,15 +1112,15 @@ def __repr__(self):
self.resource_type, self.name, self.options.region)

@property
def name(self):
def name(self) -> str:
return self.data['name']

@property
def resource_type(self):
def resource_type(self) -> str:
return self.data['resource']

@property
def provider_name(self):
def provider_name(self) -> str:
if '.' in self.resource_type:
provider_name, resource_type = self.resource_type.split('.', 1)
else:
Expand Down Expand Up @@ -1299,8 +1300,7 @@ def __call__(self):
resources = mode.provision()
else:
resources = mode.run()
# clear out resource manager post run, to clear cache
self.resource_manager = self.load_resource_manager()

return resources

run = __call__
Expand Down
6 changes: 5 additions & 1 deletion c7n/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import functools
import itertools
import json
from typing import List

import jmespath
import os
Expand Down Expand Up @@ -503,7 +504,7 @@ def get_cache_key(self, query):
'q': query
}

def resources(self, query=None, augment=True):
def resources(self, query=None, augment=True) -> List[dict]:
query = self.source.get_query_params(query)
cache_key = self.get_cache_key(query)
resources = None
Expand All @@ -527,6 +528,8 @@ def resources(self, query=None, augment=True):
# Don't pollute cache with unaugmented resources.
self._cache.save(cache_key, resources)

self._cache.close()

resource_count = len(resources)
with self.ctx.tracer.subsegment('filter'):
resources = self.filter_resources(resources)
Expand Down Expand Up @@ -555,6 +558,7 @@ def _get_cached_resources(self, ids):
m = self.get_model()
id_set = set(ids)
return [r for r in resources if r[m.id] in id_set]
self._cache.close()
return None

def get_resources(self, ids, cache=True, augment=True):
Expand Down
59 changes: 59 additions & 0 deletions c7n/resources/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import random
import re
import zlib
from typing import List
from distutils.version import LooseVersion

import botocore
from botocore.exceptions import ClientError
from dateutil.parser import parse
from concurrent.futures import as_completed
Expand Down Expand Up @@ -295,6 +298,62 @@ def __call__(self, i):
return self.operator(map(self.match, volumes))


@filters.register('stop-protected')
class DisableApiStop(Filter):
    """EC2 instances with ``disableApiStop`` attribute set

    Filters EC2 instances with ``disableApiStop`` attribute set to true.

    :Example:

    .. code-block:: yaml

        policies:
          - name: stop-protection-enabled
            resource: ec2
            filters:
              - type: stop-protected

    :Example:

    .. code-block:: yaml

        policies:
          - name: stop-protection-NOT-enabled
            resource: ec2
            filters:
              - not:
                - type: stop-protected
    """

    schema = type_schema('stop-protected')
    permissions = ('ec2:DescribeInstanceAttribute',)

    def process(self, resources: List[dict], event=None) -> List[dict]:
        """Return the subset of *resources* with stop protection enabled."""
        client = utils.local_session(
            self.manager.session_factory).client('ec2')
        return [r for r in resources
                if self._is_stop_protection_enabled(client, r)]

    def _is_stop_protection_enabled(self, client, instance: dict) -> bool:
        # One describe-attribute call per instance; manager.retry handles
        # API throttling.
        attr_val = self.manager.retry(
            client.describe_instance_attribute,
            Attribute='disableApiStop',
            InstanceId=instance['InstanceId']
        )
        return attr_val['DisableApiStop']['Value']

    @staticmethod
    def _version_tuple(version: str) -> tuple:
        # Minimal dotted-version parse for botocore's numeric versions.
        # distutils.version.LooseVersion is deprecated (PEP 632) and removed
        # from the stdlib in Python 3.12, so compare integer tuples instead.
        parts = []
        for piece in version.split('.'):
            m = re.match(r'\d+', piece)
            parts.append(int(m.group(0)) if m else 0)
        return tuple(parts)

    def validate(self) -> None:
        """Fail policy validation when botocore predates disableApiStop support."""
        botocore_min_version = '1.26.7'
        if (self._version_tuple(botocore.__version__)
                < self._version_tuple(botocore_min_version)):
            raise PolicyValidationError(
                "'stop-protected' filter requires botocore version "
                f'{botocore_min_version} or above. '
                f'Installed version is {botocore.__version__}.'
            )


@filters.register('termination-protected')
class DisableApiTermination(Filter):
"""EC2 instances with ``disableApiTermination`` attribute set
Expand Down
Loading

0 comments on commit 39aec71

Please sign in to comment.