From 7ceb60412e50454026b77350d6926072a8241da8 Mon Sep 17 00:00:00 2001 From: Andy Ayers Date: Sat, 10 Feb 2024 09:12:54 -0800 Subject: [PATCH] RLCSE: fix MCMC and GatherFeatures, overwrite dumps (#395) The JIT will append dumps to existing files, so using RLCSE to save dumps was creating large files (each sequence's dump is periodically updated to show the impact of the current parameters). Also I was going to make MCMC and such subcommands, but changed my mind, and forgot to hook the options back up. --- src/jit-rl-cse/MLCSE.cs | 17 +++++++++++++---- src/jit-rl-cse/MLCSECommands.cs | 3 +++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/jit-rl-cse/MLCSE.cs b/src/jit-rl-cse/MLCSE.cs index ed9261b..4c2b14b 100644 --- a/src/jit-rl-cse/MLCSE.cs +++ b/src/jit-rl-cse/MLCSE.cs @@ -376,7 +376,7 @@ static void PolicyGradient(IEnumerable methods) // number of times we cycle through the methods int nRounds = Get(s_commands.NumberOfRounds); // how many trials per method each cycle (minibatch) - int nIter =Get(s_commands.MinibatchSize); + int nIter = Get(s_commands.MinibatchSize); // how often to show results bool showEvery = Get(s_commands.ShowRounds); uint showEveryInterval = Get(s_commands.ShowRoundsInterval); @@ -622,10 +622,20 @@ static void PolicyGradient(IEnumerable methods) QVDumpDot(method, s); } - // Dump dasm/dump if we don't have one already + // Write out dasm/dump for method with this sequence, and baseline. + // Overwrite method dumps every so often, so we see fresh likelihood computations. + // Dasm and baselines should not change so initial ones are fine. // + bool shouldOverwriteDump = (r > 0) && (summaryInterval > 0) && (r % (4 * summaryInterval) == summaryInterval); + string cleanSequence = updateSequence.Replace(',', '_'); string dumpFile = Path.Combine(dumpDir, $"dump-{method.spmiIndex}-{cleanSequence}.d"); + + if (shouldOverwriteDump && File.Exists(dumpFile)) + { + File.Delete(dumpFile); + } + if (!File.Exists(dumpFile)) { List dumpOptions = new List(updateOptions); @@ -801,7 +811,6 @@ static void PolicyGradient(IEnumerable methods) Console.Write($" B:{MetricsParser.GetBaseLikelihoods(batchRuns[lastValidRun]),-60}"); } } - Console.Write(batchDetails[lastValidRun]); Console.ResetColor(); } } @@ -1022,7 +1031,7 @@ static void MCMC(IEnumerable methods) // Show each method's summary bool showEachCase = Get(s_commands.ShowEachMethod); // show each particular trial result - bool showEachRun = Get(s_commands.ShowEachRun); + bool showEachRun = Get(s_commands.ShowEachMCMCRun); // Show the Markov Chain bool showMC = Get(s_commands.ShowMarkovChain); // Draw the Markov Chain (tree) diff --git a/src/jit-rl-cse/MLCSECommands.cs b/src/jit-rl-cse/MLCSECommands.cs index 90166a4..69d3b7e 100644 --- a/src/jit-rl-cse/MLCSECommands.cs +++ b/src/jit-rl-cse/MLCSECommands.cs @@ -133,6 +133,9 @@ public MLCSECommands(string[] args) : base("Use ML to explore JIT CSE Heuristics Options.Add(UseSpecificMethods); Options.Add(UseAdditionalMethods); + Options.Add(GatherFeatures); + + Options.Add(DoMCMC); Options.Add(RememberMCMC); Options.Add(ShowEachMethod); Options.Add(ShowEachMCMCRun);