Refactor Core's Writers (aws#190)
* Cherry picking change of refactor writers

* set default step to 0

* remove histogram related stuff

* rename IndexUtil

* Fix imports

* remove import of re

* Fix step usage by event file writer

* Fix import errors

* Fix core test

* undo utils change worker pid

* fix import

* fix import

* do not flush index writer

* review comments
rahul003 authored Sep 17, 2019
1 parent 5c637c2 commit cb9c20e
Showing 28 changed files with 259 additions and 267 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@
 
 
 def compile_summary_protobuf():
-    proto_path = 'tornasole/core/tfevent'
+    proto_path = 'tornasole/core/tfevent/proto'
     proto_files = os.path.join(proto_path, '*.proto')
     cmd = 'protoc ' + proto_files + ' --python_out=.'
     print('compiling protobuf files in {}'.format(proto_path))
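
The relocated proto sources can also be compiled outside of setup.py; the following is a minimal sketch of the same protoc invocation, assuming protoc is installed and the command is run from the repository root (the use of glob/subprocess here is illustrative, not the setup.py implementation):

# Illustrative sketch: compile the relocated .proto files the same way
# compile_summary_protobuf() does, assuming protoc is on PATH and the
# working directory is the repository root.
import glob
import subprocess

proto_path = 'tornasole/core/tfevent/proto'
proto_files = glob.glob(proto_path + '/*.proto')
subprocess.check_call(['protoc'] + proto_files + ['--python_out=.'])
# The generated *_pb2.py modules land under tornasole/core/tfevent/proto/.
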
10 changes: 4 additions & 6 deletions tests/core/test_index.py
@@ -1,10 +1,9 @@
 from tornasole.core.writer import FileWriter
-from tornasole.core.tfevent.event_file_writer import *
 from tornasole.core.reader import FileReader
-from tornasole.core.tfevent.util import EventFileLocation
-from tornasole.core.indexutils import *
+from tornasole.core.locations import EventFileLocation, IndexFileLocationUtils
 import shutil
 import os
 import numpy as np
+import json
 
 def test_index():
@@ -23,11 +22,10 @@ def test_index():
     writer.flush()
     writer.close()
     efl = EventFileLocation(step_num=step, worker_name=worker)
-    eventfile = efl.get_location(run_dir=run_dir)
-    indexfile = IndexUtil.get_index_key_for_step(run_dir, step,worker)
+    eventfile = efl.get_location(trial_dir=run_dir)
+    indexfile = IndexFileLocationUtils.get_index_key_for_step(run_dir, step, worker)
 
     fo = open(eventfile, "rb")
-
     with open(indexfile) as idx_file:
         index_data = json.load(idx_file)
         tensor_payload = index_data['tensor_payload']
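
For reference, the location helpers exercised by this test derive both the event file path and the index key from a trial directory, step, and worker. A sketch with made-up values follows; the exact index file name is produced by IndexFileLocationUtils.indexS3Key, which is not shown in full in this diff:

# Sketch of the refactored call pattern used in test_index; paths are hypothetical.
from tornasole.core.locations import EventFileLocation, IndexFileLocationUtils

run_dir = '/tmp/ts_output/run1'  # hypothetical trial directory
step, worker = 0, 'worker0'

efl = EventFileLocation(step_num=step, worker_name=worker)
eventfile = efl.get_location(trial_dir=run_dir)
# -> /tmp/ts_output/run1/events/000000000000/000000000000_worker0.tfevents

indexfile = IndexFileLocationUtils.get_index_key_for_step(run_dir, step, worker)
# -> a key under /tmp/ts_output/run1/index/000000000/ (exact name set by indexS3Key)
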
2 changes: 1 addition & 1 deletion tests/tensorflow/hooks/test_write.py
@@ -7,7 +7,7 @@
 from tornasole.tensorflow import reset_collections
 from .utils import *
 from tornasole.core.reader import FileReader
-from tornasole.core.tfevent.util import EventFileLocation
+from tornasole.core.locations import EventFileLocation
 from tornasole.core.json_config import TORNASOLE_CONFIG_FILE_PATH_ENV_STR
 import tornasole.tensorflow as ts
 
12 changes: 6 additions & 6 deletions tornasole/core/index_reader.py
@@ -1,6 +1,6 @@
 import os
 import json
-from tornasole.core.indexutils import TensorLocation, IndexUtil
+from tornasole.core.locations import TensorLocation, IndexFileLocationUtils
 from tornasole.core.s3_utils import list_s3_objects
 from tornasole.core.access_layer.s3handler import ReadObjectRequest, S3Handler
 from tornasole.core.utils import get_logger, is_s3, list_files_in_directory, step_in_range
@@ -34,7 +34,7 @@ def get_s3_responses(bucket_name, prefix_name, start_after_key, range_steps=None
         index_files, last_index_token = S3IndexReader.list_all_index_files_from_s3(bucket_name, prefix_name,
                                                                                     start_after_key)
         for index_file in index_files:
-            step = IndexUtil.parse_step_from_index_file_name(index_file)
+            step = IndexFileLocationUtils.parse_step_from_index_file_name(index_file)
             if (range_steps is not None and step_in_range(range_steps, step)) or \
                     range_steps is None:
                 steps.append(step)
@@ -46,7 +46,7 @@ def get_s3_responses(bucket_name, prefix_name, start_after_key, range_steps=None
     @staticmethod
     def list_all_index_files_from_s3(bucket_name, prefix_name, start_after_key=None):
         index_files, last_index_token = list_s3_objects(bucket_name,
-                                                        IndexUtil.get_index_path(prefix_name),
+                                                        IndexFileLocationUtils.get_index_path(prefix_name),
                                                         start_after_key)
 
         return index_files, last_index_token
@@ -56,7 +56,7 @@ class LocalIndexReader:
 
     @staticmethod
     def list_index_files_in_dir(dirname):
-        index_dirname = IndexUtil.get_index_path(dirname)
+        index_dirname = IndexFileLocationUtils.get_index_path(dirname)
         index_files = list_files_in_directory(index_dirname)
         return sorted(index_files)
 
@@ -67,10 +67,10 @@ def get_disk_responses(path, start_after_key=0, range_steps=None)
         responses = []
         index_files = index_files[start_after_key:] # ignore files we have already read
         for index_file in index_files:
-            step = IndexUtil.parse_step_from_index_file_name(index_file)
+            step = IndexFileLocationUtils.parse_step_from_index_file_name(index_file)
             if (range_steps is not None and step_in_range(range_steps, step)) or \
                     range_steps is None:
-                steps.append(IndexUtil.parse_step_from_index_file_name(index_file))
+                steps.append(IndexFileLocationUtils.parse_step_from_index_file_name(index_file))
                 with open(index_file) as f:
                     responses.append(f.read().encode())
         start_after_key += len(index_files) # Last file that we have read
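
The filtering loop changed above is the same in both readers: parse a step number out of each index file name and keep it only if it falls in the requested step range. A condensed sketch of that pattern follows, with made-up file names; parse_step_from_index_file_name and step_in_range are the helpers referenced in the diff, and the index file naming shown here is an assumption:

# Condensed sketch of the step-filtering pattern shared by the index readers.
from tornasole.core.locations import IndexFileLocationUtils
from tornasole.core.utils import step_in_range

index_files = ['000000000/000000000040_worker0.json']  # hypothetical index file names
range_steps = (0, 100)                                  # hypothetical step range

steps = []
for index_file in index_files:
    step = IndexFileLocationUtils.parse_step_from_index_file_name(index_file)
    # Equivalent to the condition in the diff: accept all steps when no range is given.
    if range_steps is None or step_in_range(range_steps, step):
        steps.append(step)
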
63 changes: 57 additions & 6 deletions tornasole/core/indexutils.py → tornasole/core/locations.py
@@ -1,4 +1,9 @@
 import os
+import re
+
+from .utils import get_immediate_subdirectories, get_logger
+
+logger = get_logger()
 
 
 class TensorLocation:
@@ -18,18 +23,65 @@ def to_dict(self):
         }
 
 
-class IndexUtil:
+STEP_NUMBER_FORMATTING_LENGTH = '012'
+
+
+class EventFileLocation:
+    def __init__(self, step_num, worker_name, type='events'):
+        self.step_num = int(step_num)
+        self.worker_name = worker_name
+        self.type = type
+
+    def get_location(self, trial_dir=''):
+        step_num_str = str(format(self.step_num, STEP_NUMBER_FORMATTING_LENGTH))
+        event_filename = f"{step_num_str}_{self.worker_name}.tfevents"
+        if trial_dir:
+            event_key_prefix = os.path.join(trial_dir, self.type)
+        else:
+            event_key_prefix = self.type
+        return os.path.join(event_key_prefix, step_num_str, event_filename)
+
+    @staticmethod
+    def match_regex(s):
+        return EventFileLocation.load_filename(s, print_error=False)
+
+    @staticmethod
+    def load_filename(s, print_error=True):
+        event_file_name = os.path.basename(s)
+        m = re.search('(.*)_(.*).tfevents$', event_file_name)
+        if m:
+            step_num = int(m.group(1))
+            worker_name = m.group(2)
+            return EventFileLocation(step_num=step_num, worker_name=worker_name)
+        else:
+            if print_error:
+                logger.error('Failed to load efl: ', s)
+            return None
+
+    @staticmethod
+    def get_step_dirs(trial_dir):
+        return get_immediate_subdirectories(os.path.join(trial_dir,
+                                                         'events'))
+
+    @staticmethod
+    def get_step_dir_path(trial_dir, step_num):
+        step_num = int(step_num)
+        return os.path.join(trial_dir, 'events',
+                            format(step_num, STEP_NUMBER_FORMATTING_LENGTH))
+
+
+class IndexFileLocationUtils:
     # These functions are common to index reader and index writer
     MAX_INDEX_FILE_NUM_IN_INDEX_PREFIX = 1000
 
     @staticmethod
     def get_index_prefix_for_step(step_num):
-        index_prefix_for_step = step_num // IndexUtil.MAX_INDEX_FILE_NUM_IN_INDEX_PREFIX
+        index_prefix_for_step = step_num // IndexFileLocationUtils.MAX_INDEX_FILE_NUM_IN_INDEX_PREFIX
         return format(index_prefix_for_step, '09')
 
     @staticmethod
     def next_index_prefix_for_step(step_num):
-        index_prefix_for_step = step_num // IndexUtil.MAX_INDEX_FILE_NUM_IN_INDEX_PREFIX
+        index_prefix_for_step = step_num // IndexFileLocationUtils.MAX_INDEX_FILE_NUM_IN_INDEX_PREFIX
         return format(index_prefix_for_step + 1, '09')
 
     @staticmethod
@@ -42,8 +94,8 @@ def indexS3Key(trial_prefix, index_prefix_for_step_str, step_num, worker_name):
     # for a step_num index files lies in prefix step_num/MAX_INDEX_FILE_NUM_IN_INDEX_PREFIX
     @staticmethod
     def get_index_key_for_step(trial_prefix, step_num, worker_name):
-        index_prefix_for_step_str = IndexUtil.get_index_prefix_for_step(step_num)
-        return IndexUtil.indexS3Key(trial_prefix, index_prefix_for_step_str, step_num, worker_name)
+        index_prefix_for_step_str = IndexFileLocationUtils.get_index_prefix_for_step(step_num)
+        return IndexFileLocationUtils.indexS3Key(trial_prefix, index_prefix_for_step_str, step_num, worker_name)
     # let's assume worker_name is given by hook
     # We need to think on naming conventions and access patterns for:
     # 1) muti-node training --> data parallel
@@ -66,4 +118,3 @@ def parse_step_from_index_file_name(index_file_name):
     @staticmethod
     def get_index_path(path):
         return os.path.join(path, 'index')
-
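
Taken together, the renamed module provides two path helpers: EventFileLocation maps a (step, worker) pair to its .tfevents path and back, and IndexFileLocationUtils buckets index files into prefixes of 1000 steps. A usage sketch against the code above, with a made-up trial directory and worker name:

# Usage sketch for the classes above; trial_dir and worker name are hypothetical.
from tornasole.core.locations import EventFileLocation, IndexFileLocationUtils

path = EventFileLocation(step_num=1234, worker_name='worker0').get_location(trial_dir='/tmp/trial')
# -> /tmp/trial/events/000000001234/000000001234_worker0.tfevents

efl = EventFileLocation.load_filename(path)  # parse the file name back into its parts
assert (efl.step_num, efl.worker_name) == (1234, 'worker0')

# Steps 1000-1999 share index prefix '000000001'; the next bucket is '000000002'.
assert IndexFileLocationUtils.get_index_prefix_for_step(1234) == '000000001'
assert IndexFileLocationUtils.next_index_prefix_for_step(1234) == '000000002'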

4 changes: 2 additions & 2 deletions tornasole/core/tfevent/event_file_reader.py
@@ -17,11 +17,11 @@
 
 """Reads events from disk."""
 
-import tornasole.core.tfevent.types_pb2 as types_pb2
+import tornasole.core.tfevent.proto.types_pb2 as types_pb2
 import logging
 import numpy as np
 
-from .event_pb2 import Event
+from .proto.event_pb2 import Event
 
 from tornasole.core.tfrecord.record_reader import RecordReader
 from tornasole.core.modes import ModeKeys, MODE_STEP_PLUGIN_NAME, MODE_PLUGIN_NAME