diff --git a/tensorboard/plugins/custom_scalar/BUILD b/tensorboard/plugins/custom_scalar/BUILD index b6177be636..723f8791e1 100644 --- a/tensorboard/plugins/custom_scalar/BUILD +++ b/tensorboard/plugins/custom_scalar/BUILD @@ -78,11 +78,13 @@ py_test( ":summary", "//tensorboard:expect_numpy_installed", "//tensorboard:expect_tensorflow_installed", + "//tensorboard/backend:application", "//tensorboard/backend/event_processing:event_multiplexer", "//tensorboard/plugins:base_plugin", "//tensorboard/plugins/scalar:scalars_plugin", "//tensorboard/plugins/scalar:summary", "//tensorboard/util:test_util", + "@org_pocoo_werkzeug", ], ) @@ -99,12 +101,14 @@ py_test( ":summary", "//tensorboard:expect_numpy_installed", "//tensorboard:expect_tensorflow_installed", + "//tensorboard/backend:application", "//tensorboard/backend/event_processing:event_multiplexer", "//tensorboard/compat:no_tensorflow", "//tensorboard/plugins:base_plugin", "//tensorboard/plugins/scalar:scalars_plugin", "//tensorboard/plugins/scalar:summary", "//tensorboard/util:test_util", + "@org_pocoo_werkzeug", ], ) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py index 8db3619500..2deefdc268 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py @@ -121,9 +121,12 @@ def frontend_metadata(self): def download_data_route(self, request): run = request.args.get("run") tag = request.args.get("tag") + experiment = plugin_util.experiment_id(request.environ) response_format = request.args.get("format") try: - body, mime_type = self.download_data_impl(run, tag, response_format) + body, mime_type = self.download_data_impl( + run, tag, experiment, response_format + ) except ValueError as e: return http_util.Respond( request=request, diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py index 15c74afb25..3807e575d2 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py @@ -19,12 +19,18 @@ from __future__ import division from __future__ import print_function +import csv +import io +import json import os import numpy as np import tensorflow as tf +from werkzeug import test as werkzeug_test +from werkzeug import wrappers from google.protobuf import json_format +from tensorboard.backend import application from tensorboard.backend.event_processing import ( plugin_event_multiplexer as event_multiplexer, ) @@ -176,11 +182,16 @@ def createPlugin(self, logdir): ] = plugin_instance return custom_scalars_plugin_instance - def testDownloadData(self): - body, mime_type = self.plugin.download_data_impl( - "foo", "squares/scalar_summary", "exp_id", "json" + def test_download_url_json(self): + wsgi_app = application.TensorBoardWSGI([self.plugin]) + server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + response = server.get( + "/data/plugin/custom_scalars/download_data?run=%s&tag=%s" + % ("foo", "squares/scalar_summary") ) - self.assertEqual("application/json", mime_type) + self.assertEqual(200, response.status_code) + self.assertEqual("application/json", response.headers["Content-Type"]) + body = json.loads(response.get_data()) self.assertEqual(4, len(body)) for step, entry in enumerate(body): # The time stamp should be reasonable. @@ -188,6 +199,23 @@ def testDownloadData(self): self.assertEqual(step, entry[1]) np.testing.assert_allclose(step * step, entry[2]) + def test_download_url_csv(self): + wsgi_app = application.TensorBoardWSGI([self.plugin]) + server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + response = server.get( + "/data/plugin/custom_scalars/download_data?run=%s&tag=%s&format=csv" + % ("foo", "squares/scalar_summary") + ) + self.assertEqual(200, response.status_code) + self.assertEqual( + "text/csv; charset=utf-8", response.headers["Content-Type"] + ) + payload = response.get_data() + s = io.StringIO(payload.decode("utf-8")) + reader = csv.reader(s) + self.assertEqual(["Wall time", "Step", "Value"], next(reader)) + self.assertEqual(len(list(reader)), 4) + def testScalars(self): body = self.plugin.scalars_impl("bar", "increments", "exp_id") self.assertTrue(body["regex_valid"])