From b6bec9963c70b21fa0ea2a62c07d964d5231fbf3 Mon Sep 17 00:00:00 2001 From: Tim Armstrong Date: Mon, 8 Aug 2016 18:14:30 -0700 Subject: [PATCH] Codegen caching PoC --- be/src/codegen/llvm-codegen.cc | 101 +++++++++++++++++++++++++++++++---------- be/src/codegen/llvm-codegen.h | 5 +- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/be/src/codegen/llvm-codegen.cc b/be/src/codegen/llvm-codegen.cc index fb038cd..da5b120 100644 --- a/be/src/codegen/llvm-codegen.cc +++ b/be/src/codegen/llvm-codegen.cc @@ -65,6 +65,7 @@ #include "util/hdfs-util.h" #include "util/path-builder.h" #include "util/runtime-profile-counters.h" +#include "util/spinlock.h" #include "util/test-info.h" #include "common/names.h" @@ -98,6 +99,14 @@ string LlvmCodeGen::cpu_name_; vector LlvmCodeGen::cpu_attrs_; unordered_set LlvmCodeGen::gv_ref_ir_fns_; +struct CodegenCacheEntry { + unique_ptr context; + unique_ptr execution_engine; +}; + +SpinLock codegen_cache_lock; +unordered_map codegen_cache; + static void LlvmCodegenHandleError(void* user_data, const std::string& reason, bool gen_crash_diag) { LOG(FATAL) << "LLVM hit fatal error: " << reason.c_str(); @@ -490,6 +499,7 @@ void LlvmCodeGen::SetupJITListeners() { LlvmCodeGen::~LlvmCodeGen() { // Execution engine executes callback on event listener, so tear down engine first. + // TODO: need to solve this problem. execution_engine_.reset(); symbol_emitter_.reset(); } @@ -915,6 +925,44 @@ Status LlvmCodeGen::FinalizeModule() { if (fns_to_jit_compile_.empty()) return Status::OK(); RETURN_IF_ERROR(FinalizeLazyMaterialization()); + PruneModule(); + // Update counters before final optimization, but after removing unused functions. This + // gives us a rough measure of how much work the optimization and compilation must do. + InstructionCounter counter; + counter.visit(*module_); + COUNTER_SET(num_functions_, counter.GetCount(InstructionCounter::TOTAL_FUNCTIONS)); + COUNTER_SET(num_instructions_, counter.GetCount(InstructionCounter::TOTAL_INSTS)); + + + string bitcode; + raw_string_ostream bitcode_stream(bitcode); + + MonotonicStopWatch sw; + sw.Start(); + WriteBitcodeToFile(module_, bitcode_stream); + sw.Stop(); + bitcode_stream.flush(); + LOG(INFO) << "Wrote to string: " << bitcode.size() << " in " << (sw.ElapsedTime() / 1000) << "us"; + { + lock_guard l(codegen_cache_lock); + auto it = codegen_cache.find(bitcode); + if (it != codegen_cache.end()) { + LOG(INFO) << "CACHE HIT"; + // Get pointers to all codegen'd functions. TODO: combine logic with the regular path + for (int i = 0; i < fns_to_jit_compile_.size(); ++i) { + Function* function = fns_to_jit_compile_[i].first; + LOG(INFO) << "Lookup " << function->getName().data(); + void* jitted_function = reinterpret_cast( + it->second.execution_engine->getFunctionAddress(function->getName())); + DCHECK(jitted_function != NULL) << "Failed to jit " << function->getName().data(); + *fns_to_jit_compile_[i].second = jitted_function; + LOG(INFO) << "Got " << function->getName().data() << " ok"; + } + return Status::OK(); + } + } + LOG(INFO) << "CACHE MISS"; + if (optimizations_enabled_ && !FLAGS_disable_optimization_passes) OptimizeModule(); if (FLAGS_opt_module_dir.size() != 0) { @@ -941,9 +989,38 @@ Status LlvmCodeGen::FinalizeModule() { DCHECK(jitted_function != NULL) << "Failed to jit " << function->getName().data(); *fns_to_jit_compile_[i].second = jitted_function; } + + CodegenCacheEntry cache_entry; + cache_entry.context = std::move(context_); + cache_entry.execution_engine = std::move(execution_engine_); + { + lock_guard l(codegen_cache_lock); + codegen_cache.insert({std::move(bitcode), std::move(cache_entry)}); + } return Status::OK(); } +void LlvmCodeGen::PruneModule() { + SCOPED_TIMER(optimization_timer_); + TargetIRAnalysis target_analysis = + execution_engine_->getTargetMachine()->getTargetIRAnalysis(); + + // Before running any other optimization passes, run the internalize pass, giving it + // the names of all functions registered by AddFunctionToJit(), followed by the + // global dead code elimination pass. This causes all functions not registered to be + // JIT'd to be marked as internal, and any internal functions that are not used are + // deleted by DCE pass. This greatly decreases compile time by removing unused code. + vector exported_fn_names; + for (int i = 0; i < fns_to_jit_compile_.size(); ++i) { + exported_fn_names.push_back(fns_to_jit_compile_[i].first->getName().data()); + } + unique_ptr module_pass_manager(new legacy::PassManager()); + module_pass_manager->add(createTargetTransformInfoWrapperPass(target_analysis)); + module_pass_manager->add(createInternalizePass(exported_fn_names)); + module_pass_manager->add(createGlobalDCEPass()); + module_pass_manager->run(*module_); +} + void LlvmCodeGen::OptimizeModule() { SCOPED_TIMER(optimization_timer_); @@ -966,28 +1043,6 @@ void LlvmCodeGen::OptimizeModule() { TargetIRAnalysis target_analysis = execution_engine_->getTargetMachine()->getTargetIRAnalysis(); - // Before running any other optimization passes, run the internalize pass, giving it - // the names of all functions registered by AddFunctionToJit(), followed by the - // global dead code elimination pass. This causes all functions not registered to be - // JIT'd to be marked as internal, and any internal functions that are not used are - // deleted by DCE pass. This greatly decreases compile time by removing unused code. - vector exported_fn_names; - for (int i = 0; i < fns_to_jit_compile_.size(); ++i) { - exported_fn_names.push_back(fns_to_jit_compile_[i].first->getName().data()); - } - unique_ptr module_pass_manager(new legacy::PassManager()); - module_pass_manager->add(createTargetTransformInfoWrapperPass(target_analysis)); - module_pass_manager->add(createInternalizePass(exported_fn_names)); - module_pass_manager->add(createGlobalDCEPass()); - module_pass_manager->run(*module_); - - // Update counters before final optimization, but after removing unused functions. This - // gives us a rough measure of how much work the optimization and compilation must do. - InstructionCounter counter; - counter.visit(*module_); - COUNTER_SET(num_functions_, counter.GetCount(InstructionCounter::TOTAL_FUNCTIONS)); - COUNTER_SET(num_instructions_, counter.GetCount(InstructionCounter::TOTAL_INSTS)); - // Create and run function pass manager unique_ptr fn_pass_manager( new legacy::FunctionPassManager(module_)); @@ -1000,7 +1055,7 @@ void LlvmCodeGen::OptimizeModule() { fn_pass_manager->doFinalization(); // Create and run module pass manager - module_pass_manager.reset(new legacy::PassManager()); + unique_ptr module_pass_manager(new legacy::PassManager()); module_pass_manager->add(createTargetTransformInfoWrapperPass(target_analysis)); pass_builder.populateModulePassManager(*module_pass_manager); module_pass_manager->run(*module_); diff --git a/be/src/codegen/llvm-codegen.h b/be/src/codegen/llvm-codegen.h index fa465fe..7093ef8 100644 --- a/be/src/codegen/llvm-codegen.h +++ b/be/src/codegen/llvm-codegen.h @@ -496,7 +496,10 @@ class LlvmCodeGen { // Used for testing. void ResetVerification() { is_corrupt_ = false; } - /// Optimizes the module. This includes pruning the module of any unused functions. + // Prunes any unused functions from the module. + void PruneModule(); + + /// Optimizes the module using LLVM's optimizer. void OptimizeModule(); /// Clears generated hash fns. This is only used for testing. -- 2.5.0