From 603ae26825cb9653398bd55c24a994253f41eaba Mon Sep 17 00:00:00 2001
From: Tianyi Wang <tianyi@apache.org>
Date: Sat, 9 Jun 2018 16:51:11 -0700
Subject: [PATCH] WIP: IMPALA-3816,IMPALA-4065: full TupleRowComparator codegen

This patch removes the indirection of codegened TupleRowComparator::
Compare() at its call sites in sorter and topn node. It's implemented by
cloning and modifying all the functions between the codegened entry
function and Compare().

The call site in SortedRunMerger is still indirect, but the indirection
could be removed in the same way as this patch.

Testing:
* TODO: unit tests for SCC algo

Perf:
Tianyi said: "TPCH queries with a sort node are 2%-5% faster."
TODO: run benchmarks

Change-Id: I77f1b88a7f8ac9976f143fbea7be0c6fcb01ff07
---
 be/src/codegen/gen_ir_descriptions.py              |   4 +
 be/src/codegen/llvm-codegen.cc                     | 154 +++++++++++++++
 be/src/codegen/llvm-codegen.h                      |  18 ++
 be/src/exec/exchange-node.cc                       |   8 +-
 be/src/exec/topn-node.cc                           |  81 ++++----
 be/src/runtime/data-stream-test.cc                 |   3 +-
 be/src/runtime/sorted-run-merger.cc                |   4 +-
 be/src/runtime/sorter-internal.h                   |   5 +
 be/src/runtime/sorter-ir.cc                        |   2 +
 be/src/runtime/sorter.cc                           |  37 +++-
 be/src/util/tuple-row-compare.cc                   | 213 +++++++++------------
 be/src/util/tuple-row-compare.h                    |  79 ++++----
 .../test_replace_tuple_row_compare.py              |  52 +++++
 13 files changed, 444 insertions(+), 216 deletions(-)
 create mode 100644 tests/custom_cluster/test_replace_tuple_row_compare.py

diff --git a/be/src/codegen/gen_ir_descriptions.py b/be/src/codegen/gen_ir_descriptions.py
index 4829539..824e89d 100755
--- a/be/src/codegen/gen_ir_descriptions.py
+++ b/be/src/codegen/gen_ir_descriptions.py
@@ -202,6 +202,10 @@ ir_functions = [
    "_ZN6impala8RawValue20GetHashValueFastHashEPKvRKNS_10ColumnTypeEm"],
   ["TOPN_NODE_INSERT_BATCH",
    "_ZN6impala8TopNNode11InsertBatchEPNS_8RowBatchE"],
+  ["SORTER_SORTHELPER",
+   "_ZN6impala6Sorter11TupleSorter10SortHelperENS0_13TupleIteratorES2_"],
+  ["COMPARE_INTERPRETED",
+   "_ZNK6impala18TupleRowComparator18CompareInterpretedEPKNS_8TupleRowES3_"],
   ["MEMPOOL_ALLOCATE",
    "_ZN6impala7MemPool8AllocateILb0EEEPhli"],
   ["MEMPOOL_CHECKED_ALLOCATE",
diff --git a/be/src/codegen/llvm-codegen.cc b/be/src/codegen/llvm-codegen.cc
index d6f89a3..7bac776 100644
--- a/be/src/codegen/llvm-codegen.cc
+++ b/be/src/codegen/llvm-codegen.cc
@@ -928,6 +928,160 @@ int LlvmCodeGen::ReplaceCallSites(
   return replaced;
 }
 
+// If 'inst' is a call or invoke instruction with a known callee (i.e. a direct call),
+// then return the callee. Otherwise return nullptr. The returned callee may be an
+// external function, i.e. not defined in the IR.
+static llvm::Function* GetCallee(llvm::Instruction* inst) {
+  llvm::Function* callee = nullptr;
+  if (llvm::isa<llvm::CallInst>(inst)) {
+    callee = reinterpret_cast<llvm::CallInst*>(inst)->getCalledFunction();
+  } else if (llvm::isa<llvm::InvokeInst>(inst)) {
+    callee = reinterpret_cast<llvm::InvokeInst*>(inst)->getCalledFunction();
+  }
+  return callee;
+}
+
+// Find the next callee at '*iter' or after it. Callees are returned according to
+// the same rules as GetCallee(). If no callee can be found, return nullptr.
+// 'iter' is advanced to the instruction after the found call instruction.
+static llvm::Function* FindNextCallee(llvm::inst_iterator* iter) {
+    llvm::Function* callee = nullptr;
+  for (; callee == nullptr && !iter->atEnd(); ++(*iter)) {
+    callee = GetCallee(&**iter);
+  }
+  return callee;
+}
+
+void LlvmCodeGen::ReplaceCallSitesRecursivelyHelper(llvm::Function* old_fn,
+    llvm::Function* new_fn, const vector<llvm::Function*>& fns_to_clone,
+    llvm::ValueMap<const llvm::Value*, llvm::WeakTrackingVH>* cloned_fn_map) {
+  // Do a pass over functions to be cloned to build mapping from the old to new functions.
+  vector<llvm::Function*> cloned_fns;
+  cloned_fn_map->insert({old_fn, llvm::WeakTrackingVH(new_fn)});
+  for (llvm::Function* f : fns_to_clone) {
+    cloned_fns.push_back(CloneFunction(f));
+    cloned_fn_map->insert({f, llvm::WeakTrackingVH(cloned_fns.back())});
+  }
+  // Second pass replaces all of the functions in the cloned functions.
+  for (llvm::Function* f : cloned_fns) {
+    RemapFunction(*f, *cloned_fn_map, llvm::RF_IgnoreMissingLocals);
+  }
+}
+
+llvm::Function* LlvmCodeGen::ReplaceCallSitesRecursively(
+    llvm::Function* caller, llvm::Function* old_fn, llvm::Function* new_fn) {
+  // This function finds all the functions between 'caller' and 'old_fn' and makes clones
+  // of them. Call sites between these functions are replaced with the cloned functions.
+  // Call sites of 'old_fn' are replaced with 'new_fn'. Functions between 'caller' and
+  // 'old_fn' are found by the Tarjan's SCC algorithm, which correctly handles arbitrary
+  // cycles that may occur in the call graph.
+  DCHECK(!is_compiled_);
+  DCHECK(caller->getParent() == module_);
+  DCHECK(caller != nullptr);
+  // Discovered functions whose SCC membership hasn't been concluded.
+  std::stack<llvm::Function*> scc_unknown_stack;
+  struct FunctionInfo {
+    FunctionInfo(int64_t dfs_index, int64_t low_link) : dfs_index(dfs_index),
+        low_link(low_link) {}
+    // This index numbers the functions consecutively in the order in which they are
+    // discovered.
+    const int64_t dfs_index;
+    // The lowest index of any function known to be reachable from this function.
+    // Updated when a back or cross edge is found.
+    int64_t low_link;
+    // Whether 'old_fn' is reachable from this function. This flag is final after the SCC
+    // membership of this function is concluded.
+    bool old_fn_reachable = false;
+    bool on_scc_unknown_stack = false;
+  };
+  std::unordered_map<llvm::Function*, FunctionInfo> visited_fns;
+  struct DfsStackElem {
+    DfsStackElem(llvm::Function* f, const llvm::inst_iterator& iter) : f(f), iter(iter) {}
+    // The function being searched.
+    llvm::Function* f;
+    // The next untraversed instruction in f.
+    llvm::inst_iterator iter;
+  };
+  // The callee will return its FunctionInfo to the caller through this iterator, so that
+  // the caller doesn't need to query the visited_fns map again.
+  auto popped_info = visited_fns.end();
+  // The "stack frames" for the DFS of the Tarjan's SCC algorithm.
+  std::stack<DfsStackElem> dfs_stack;
+  // Functions from which 'old_fn' is reachable, which we therefore need to clone.
+  vector<llvm::Function*> fns_to_clone;
+
+  dfs_stack.emplace(caller, inst_begin(caller));
+  int64_t next_dfx_index = 0;
+  while (!dfs_stack.empty()) {
+    auto top_info = visited_fns.find(dfs_stack.top().f);
+    if (top_info == visited_fns.end()) {
+      // Initialize the newly discovered function.
+      bool inserted;
+      std::tie(top_info, inserted) = visited_fns.emplace(
+          dfs_stack.top().f, FunctionInfo(next_dfx_index, next_dfx_index));
+      DCHECK(inserted);
+      scc_unknown_stack.emplace(dfs_stack.top().f);
+      ++next_dfx_index;
+    } else if (popped_info != visited_fns.end()) {
+      // A lower dfs_index may be reachable via this callee.
+      top_info->second.low_link =
+          std::min(top_info->second.low_link, popped_info->second.low_link);
+      // 'old_fn' is reachable from this function if it's reachable from its callee.
+      top_info->second.old_fn_reachable |= popped_info->second.old_fn_reachable;
+      popped_info = visited_fns.end();
+    }
+
+    // Scan for the next callee in the current function.
+    llvm::Function* next_callee = FindNextCallee(&dfs_stack.top().iter);
+    if (dfs_stack.top().iter.atEnd()) {
+      if (top_info->second.dfs_index == top_info->second.low_link) {
+        // A new SCC is discovered. 'scc_unknown_stack' contains all functions in the SCC.
+        // 'old_fn' is reachable from all of the SCC if it is reachable from this
+        // function. 'old_fn_reachable' is set if it's reachable from the SCC root.
+        llvm::Function* scc_member = nullptr;
+        do {
+          scc_member = scc_unknown_stack.top();
+          scc_unknown_stack.pop();
+          auto unassigned_top_info = visited_fns.find(scc_member);
+          DCHECK(unassigned_top_info != visited_fns.end());
+          unassigned_top_info->second.on_scc_unknown_stack = false;
+          if (top_info->second.old_fn_reachable) {
+            unassigned_top_info->second.old_fn_reachable = true;
+            fns_to_clone.emplace_back(scc_member);
+          }
+        } while (scc_member != dfs_stack.top().f);
+      }
+      popped_info = top_info;
+      dfs_stack.pop();
+    } else if (next_callee == old_fn) {
+      // Update 'old_fn_reachable' and don't DFS into 'old_fn'
+      top_info->second.old_fn_reachable = true;
+    } else if (next_callee->empty()) {
+      // Follow edge to callees, so long as it has a body (i.e. is not external). We can't
+      // substitute calls in external functions.
+      auto callee_info = visited_fns.find(next_callee);
+      if (callee_info == visited_fns.end()) {
+        // Found a tree edge. Push the stack and DFS from it.
+        dfs_stack.emplace(next_callee, inst_begin(next_callee));
+      } else {
+        if (callee_info->second.on_scc_unknown_stack) {
+          // Found a back edge or a cross edge to an SCC-unknown node. Take its DFS index.
+          top_info->second.low_link =
+              std::min(top_info->second.low_link, callee_info->second.dfs_index);
+        }
+        top_info->second.old_fn_reachable |= callee_info->second.old_fn_reachable;
+      }
+    }
+  }
+
+  // Ensure that 'old_fn' is reachable from 'caller'.
+  DCHECK(popped_info != visited_fns.end());
+  DCHECK(popped_info->second.old_fn_reachable);
+  llvm::ValueToValueMapTy fn_mapping;
+  ReplaceCallSitesRecursivelyHelper(old_fn, new_fn, fns_to_clone, &fn_mapping);
+  return llvm::cast<llvm::Function>(&*fn_mapping[caller]);
+}
+
 int LlvmCodeGen::ReplaceCallSitesWithValue(
     llvm::Function* caller, llvm::Value* replacement, const string& target_name) {
   DCHECK(!is_compiled_);
diff --git a/be/src/codegen/llvm-codegen.h b/be/src/codegen/llvm-codegen.h
index 97300a0..0aea80c 100644
--- a/be/src/codegen/llvm-codegen.h
+++ b/be/src/codegen/llvm-codegen.h
@@ -35,6 +35,7 @@
 #include <llvm/IR/Intrinsics.h>
 #include <llvm/IR/LLVMContext.h>
 #include <llvm/IR/Module.h>
+#include <llvm/IR/ValueMap.h>
 #include <llvm/Support/MemoryBuffer.h>
 #include <llvm/Support/raw_ostream.h>
 
@@ -369,6 +370,15 @@ class LlvmCodeGen {
   int ReplaceCallSites(llvm::Function* caller, llvm::Function* new_fn,
       const std::string& target_name);
 
+  /// Replace all the callsites of 'old_fn' reachable from 'caller' with 'new_fn'.
+  /// Functions that could possibly be between the call stack of 'caller' and 'old_fn'
+  /// are cloned and then modified.
+  /// TODO: if, in future, we need to replace multiple functions at the same time, we
+  /// should generalize this function to take a map of old->new functions, instead of
+  /// doing multiple passes cloning the IR.
+  llvm::Function* ReplaceCallSitesRecursively(llvm::Function* caller,
+      llvm::Function* old_fn, llvm::Function* new_fn);
+
   /// Same as ReplaceCallSites(), except replaces the function call instructions with the
   /// boolean value 'constant'.
   int ReplaceCallSitesWithBoolConst(llvm::Function* caller, bool constant,
@@ -693,6 +703,14 @@ class LlvmCodeGen {
   static void FindCallSites(llvm::Function* caller, const std::string& target_name,
       std::vector<llvm::CallInst*>* results);
 
+  /// Clone all the functions in 'fns_to_clone', replace 'old_fn' with 'new_fn' in
+  /// cloned versions of 'fns_to_clone', update any callsites in the callgraph
+  /// of 'fns_to_clone' and return a map of the original functions to cloned functions
+  /// in 'cloned_fn_map'.
+  void ReplaceCallSitesRecursivelyHelper(llvm::Function* old_fn, llvm::Function* new_fn,
+      const vector<llvm::Function*>& fns_to_clone,
+      llvm::ValueMap<const llvm::Value*, llvm::WeakTrackingVH>* cloned_fn_map);
+
   /// This function parses the function body of the given function 'fn' and materializes
   /// any functions called by it.
   Status MaterializeCallees(llvm::Function* fn);
diff --git a/be/src/exec/exchange-node.cc b/be/src/exec/exchange-node.cc
index 04e7de0..e6c7328 100644
--- a/be/src/exec/exchange-node.cc
+++ b/be/src/exec/exchange-node.cc
@@ -96,6 +96,8 @@ Status ExchangeNode::Prepare(RuntimeState* state) {
   if (is_merging_) {
     less_than_.reset(
         new TupleRowComparator(ordering_exprs_, is_asc_order_, nulls_first_));
+    RETURN_IF_ERROR(less_than_->Prepare(
+        pool_, state, expr_perm_pool(), expr_results_pool()));
     state->CheckAndAddCodegenDisabledMessage(runtime_profile());
   }
   return Status::OK();
@@ -105,7 +107,6 @@ void ExchangeNode::Codegen(RuntimeState* state) {
   DCHECK(state->ShouldCodegen());
   ExecNode::Codegen(state);
   if (IsNodeCodegenDisabled()) return;
-
   if (is_merging_) {
     Status codegen_status = less_than_->Codegen(state);
     runtime_profile()->AddCodegenMsg(codegen_status.ok(), codegen_status);
@@ -120,9 +121,8 @@ Status ExchangeNode::Open(RuntimeState* state) {
   if (is_merging_) {
     // CreateMerger() will populate its merging heap with batches from the stream_recvr_,
     // so it is not necessary to call FillInputRowBatch().
-    RETURN_IF_ERROR(
-        less_than_->Open(pool_, state, expr_perm_pool(), expr_results_pool()));
-    RETURN_IF_ERROR(stream_recvr_->CreateMerger(*less_than_.get()));
+    RETURN_IF_ERROR(less_than_->Open(state));
+    RETURN_IF_ERROR(stream_recvr_->CreateMerger(*less_than_));
   } else {
     RETURN_IF_ERROR(FillInputRowBatch(state));
   }
diff --git a/be/src/exec/topn-node.cc b/be/src/exec/topn-node.cc
index 66c647e..f5d3e54 100644
--- a/be/src/exec/topn-node.cc
+++ b/be/src/exec/topn-node.cc
@@ -78,6 +78,7 @@ Status TopNNode::Prepare(RuntimeState* state) {
       expr_perm_pool(), expr_results_pool(), &output_tuple_expr_evals_));
   tuple_row_less_than_.reset(
       new TupleRowComparator(ordering_exprs_, is_asc_order_, nulls_first_));
+  tuple_row_less_than_->Prepare(pool_, state, expr_perm_pool(), expr_results_pool());
   output_tuple_desc_ = row_descriptor_.tuple_descriptors()[0];
   insert_batch_timer_ = ADD_TIMER(runtime_profile(), "InsertBatchTime");
   state->CheckAndAddCodegenDisabledMessage(runtime_profile());
@@ -92,45 +93,48 @@ void TopNNode::Codegen(RuntimeState* state) {
   if (IsNodeCodegenDisabled()) return;
 
   LlvmCodeGen* codegen = state->codegen();
-  DCHECK(codegen != NULL);
+  DCHECK(codegen != nullptr);
+
+  // No need to copy the function because ReplaceCallSitesRecursively() will copy-on-write
+  // it.
+  llvm::Function* insert_batch_fn =
+      codegen->GetFunction(IRFunction::TOPN_NODE_INSERT_BATCH, false);
+  DCHECK(insert_batch_fn != nullptr);
+
+  // Generate two MaterializeExprs() functions, one using tuple_pool_ and
+  // one with no pool.
+  DCHECK(output_tuple_desc_ != nullptr);
+  llvm::Function* materialize_exprs_tuple_pool_fn;
+  llvm::Function* materialize_exprs_no_pool_fn;
+  llvm::Function* tuple_compare_fn;
+
+  Status codegen_status = Tuple::CodegenMaterializeExprs(codegen, false,
+      *output_tuple_desc_, output_tuple_exprs_,
+      true, &materialize_exprs_tuple_pool_fn);
 
-  // TODO: inline tuple_row_less_than_->Compare()
-  Status codegen_status = tuple_row_less_than_->Codegen(state);
   if (codegen_status.ok()) {
-    llvm::Function* insert_batch_fn =
-        codegen->GetFunction(IRFunction::TOPN_NODE_INSERT_BATCH, true);
-    DCHECK(insert_batch_fn != NULL);
-
-    // Generate two MaterializeExprs() functions, one using tuple_pool_ and
-    // one with no pool.
-    DCHECK(output_tuple_desc_ != NULL);
-    llvm::Function* materialize_exprs_tuple_pool_fn;
-    llvm::Function* materialize_exprs_no_pool_fn;
-
-    codegen_status = Tuple::CodegenMaterializeExprs(codegen, false,
-        *output_tuple_desc_, output_tuple_exprs_,
-        true, &materialize_exprs_tuple_pool_fn);
-
-    if (codegen_status.ok()) {
-      codegen_status = Tuple::CodegenMaterializeExprs(codegen, false,
-          *output_tuple_desc_, output_tuple_exprs_,
-          false, &materialize_exprs_no_pool_fn);
-
-      if (codegen_status.ok()) {
-        int replaced = codegen->ReplaceCallSites(insert_batch_fn,
-            materialize_exprs_tuple_pool_fn, Tuple::MATERIALIZE_EXPRS_SYMBOL);
-        DCHECK_REPLACE_COUNT(replaced, 1) << LlvmCodeGen::Print(insert_batch_fn);
-
-        replaced = codegen->ReplaceCallSites(insert_batch_fn,
-            materialize_exprs_no_pool_fn, Tuple::MATERIALIZE_EXPRS_NULL_POOL_SYMBOL);
-        DCHECK_REPLACE_COUNT(replaced, 1) << LlvmCodeGen::Print(insert_batch_fn);
-
-        insert_batch_fn = codegen->FinalizeFunction(insert_batch_fn);
-        DCHECK(insert_batch_fn != NULL);
-        codegen->AddFunctionToJit(insert_batch_fn,
-            reinterpret_cast<void**>(&codegend_insert_batch_fn_));
-      }
-    }
+    codegen_status = Tuple::CodegenMaterializeExprs(codegen, false, *output_tuple_desc_,
+        output_tuple_exprs_, false, &materialize_exprs_no_pool_fn);
+  }
+  if (codegen_status.ok()) {
+    codegen_status = tuple_row_less_than_->CodegenCompare(codegen, &tuple_compare_fn);
+  }
+  if (codegen_status.ok()) {
+    insert_batch_fn = codegen->ReplaceCallSitesRecursively(insert_batch_fn,
+        codegen->GetFunction(IRFunction::COMPARE_INTERPRETED, false), tuple_compare_fn);
+
+    int replaced = codegen->ReplaceCallSites(insert_batch_fn,
+        materialize_exprs_tuple_pool_fn, Tuple::MATERIALIZE_EXPRS_SYMBOL);
+    DCHECK_REPLACE_COUNT(replaced, 1) << LlvmCodeGen::Print(insert_batch_fn);
+
+    replaced = codegen->ReplaceCallSites(insert_batch_fn,
+        materialize_exprs_no_pool_fn, Tuple::MATERIALIZE_EXPRS_NULL_POOL_SYMBOL);
+    DCHECK_REPLACE_COUNT(replaced, 1) << LlvmCodeGen::Print(insert_batch_fn);
+
+    insert_batch_fn = codegen->FinalizeFunction(insert_batch_fn);
+    DCHECK(insert_batch_fn != nullptr);
+    codegen->AddFunctionToJit(insert_batch_fn,
+        reinterpret_cast<void**>(&codegend_insert_batch_fn_));
   }
   runtime_profile()->AddCodegenMsg(codegen_status.ok(), codegen_status);
 }
@@ -139,8 +143,7 @@ Status TopNNode::Open(RuntimeState* state) {
   SCOPED_TIMER(runtime_profile_->total_time_counter());
   ScopedOpenEventAdder ea(this);
   RETURN_IF_ERROR(ExecNode::Open(state));
-  RETURN_IF_ERROR(
-      tuple_row_less_than_->Open(pool_, state, expr_perm_pool(), expr_results_pool()));
+  RETURN_IF_ERROR(tuple_row_less_than_->Open(state));
   RETURN_IF_ERROR(ScalarExprEvaluator::Open(output_tuple_expr_evals_, state));
   RETURN_IF_CANCELLED(state);
   RETURN_IF_ERROR(QueryMaintenance(state));
diff --git a/be/src/runtime/data-stream-test.cc b/be/src/runtime/data-stream-test.cc
index 648356c..5ed9f76 100644
--- a/be/src/runtime/data-stream-test.cc
+++ b/be/src/runtime/data-stream-test.cc
@@ -342,8 +342,9 @@ class DataStreamTest : public testing::Test {
     ordering_exprs_.push_back(lhs_slot);
     less_than_ = obj_pool_.Add(new TupleRowComparator(ordering_exprs_,
         is_asc_, nulls_first_));
-    ASSERT_OK(less_than_->Open(
+    ASSERT_OK(less_than_->Prepare(
         &obj_pool_, runtime_state_.get(), mem_pool_.get(), mem_pool_.get()));
+    ASSERT_OK(less_than_->Open(runtime_state_.get()));
   }
 
   // Create batch_, but don't fill it with data yet. Assumes we created row_desc_.
diff --git a/be/src/runtime/sorted-run-merger.cc b/be/src/runtime/sorted-run-merger.cc
index 64feeb7..c874e2f 100644
--- a/be/src/runtime/sorted-run-merger.cc
+++ b/be/src/runtime/sorted-run-merger.cc
@@ -109,7 +109,7 @@ void SortedRunMerger::Heapify(int parent_index) {
   int least_child;
   // Find the least child of parent.
   if (right_index >= min_heap_.size() ||
-      comparator_.Less(
+      comparator_.MaybeCodegenedLess(
           min_heap_[left_index]->current_row(), min_heap_[right_index]->current_row())) {
     least_child = left_index;
   } else {
@@ -118,7 +118,7 @@ void SortedRunMerger::Heapify(int parent_index) {
 
   // If the parent is out of place, swap it with the least child and invoke
   // Heapify recursively.
-  if (comparator_.Less(min_heap_[least_child]->current_row(),
+  if (comparator_.MaybeCodegenedLess(min_heap_[least_child]->current_row(),
           min_heap_[parent_index]->current_row())) {
     iter_swap(min_heap_.begin() + least_child, min_heap_.begin() + parent_index);
     Heapify(least_child);
diff --git a/be/src/runtime/sorter-internal.h b/be/src/runtime/sorter-internal.h
index ea8275a..68639df 100644
--- a/be/src/runtime/sorter-internal.h
+++ b/be/src/runtime/sorter-internal.h
@@ -433,6 +433,9 @@ class Sorter::TupleSorter {
   /// query is cancelled.
   Status Sort(Run* run);
 
+  // Codegen codegened_sort_helper_. It will be called in place of SortHelper().
+  Status Codegen(llvm::Function* ComparatorLessFn, RuntimeState* state);
+
  private:
   static const int INSERTION_THRESHOLD = 16;
 
@@ -464,6 +467,8 @@ class Sorter::TupleSorter {
   /// high: Mersenne Twister should be more than adequate.
   std::mt19937_64 rng_;
 
+  Status (*codegened_sort_helper_) (TupleSorter*, TupleIterator, TupleIterator);
+
   /// Wrapper around comparator_.Less(). Also call expr_results_pool_.Clear()
   /// on every 'state_->batch_size()' invocations of comparator_.Less(). Returns true
   /// if 'lhs' is less than 'rhs'.
diff --git a/be/src/runtime/sorter-ir.cc b/be/src/runtime/sorter-ir.cc
index 4da6bdd..0496347 100644
--- a/be/src/runtime/sorter-ir.cc
+++ b/be/src/runtime/sorter-ir.cc
@@ -76,6 +76,7 @@ Status Sorter::TupleSorter::Partition(TupleIterator begin,
     TupleIterator end, const Tuple* pivot, TupleIterator* cut) {
   // Hoist member variable lookups out of loop to avoid extra loads inside loop.
   Run* run = run_;
+  // TODO: codegen tuple_size? Yes! That would reduce memcpy() to fixed size.
   int tuple_size = tuple_size_;
   Tuple* temp_tuple = reinterpret_cast<Tuple*>(temp_tuple_buffer_);
   Tuple* swap_tuple = reinterpret_cast<Tuple*>(swap_buffer_);
@@ -195,6 +196,7 @@ Tuple* Sorter::TupleSorter::SelectPivot(TupleIterator begin, TupleIterator end)
   // less than 1%. Since selection is random each time, the chance of repeatedly picking
   // bad pivots decreases exponentialy and becomes negligibly small after a few
   // iterations.
+  // TODO: don't codegen this?
   Tuple* tuples[3];
   for (auto& tuple : tuples) {
     int64_t index = boost::uniform_int<int64_t>(begin.index(), end.index() - 1)(rng_);
diff --git a/be/src/runtime/sorter.cc b/be/src/runtime/sorter.cc
index a7a5a64..3a1566f 100644
--- a/be/src/runtime/sorter.cc
+++ b/be/src/runtime/sorter.cc
@@ -21,6 +21,8 @@
 #include <boost/random/uniform_int.hpp>
 #include <gutil/strings/substitute.h>
 
+#include "codegen/llvm-codegen.h"
+#include "common/compiler-util.h"
 #include "runtime/bufferpool/reservation-tracker.h"
 #include "runtime/exec-env.h"
 #include "runtime/mem-tracker.h"
@@ -736,7 +738,8 @@ Sorter::TupleSorter::TupleSorter(Sorter* parent, const TupleRowComparator& comp,
     tuple_size_(tuple_size),
     comparator_(comp),
     num_comparisons_till_free_(state->batch_size()),
-    state_(state) {
+    state_(state),
+    codegened_sort_helper_(nullptr) {
   temp_tuple_buffer_ = new uint8_t[tuple_size];
   swap_buffer_ = new uint8_t[tuple_size];
 }
@@ -750,11 +753,32 @@ Status Sorter::TupleSorter::Sort(Run* run) {
   DCHECK(run->is_finalized());
   DCHECK(!run->is_sorted());
   run_ = run;
-  RETURN_IF_ERROR(SortHelper(TupleIterator::Begin(run_), TupleIterator::End(run_)));
+  if (codegened_sort_helper_ != nullptr) {
+    RETURN_IF_ERROR(codegened_sort_helper_(this,
+        TupleIterator::Begin(run_), TupleIterator::End(run_)));
+  } else {
+    RETURN_IF_ERROR(SortHelper(TupleIterator::Begin(run_), TupleIterator::End(run_)));
+  }
+
   run_->set_sorted();
   return Status::OK();
 }
 
+Status Sorter::TupleSorter::Codegen(llvm::Function* compare_fn, RuntimeState* state) {
+  LlvmCodeGen* codegen = state->codegen();
+  llvm::Function* sort_helper_fn =
+      codegen->GetFunction(IRFunction::SORTER_SORTHELPER, false);
+  llvm::Function* replaced = codegen->ReplaceCallSitesRecursively(sort_helper_fn,
+      codegen->GetFunction(IRFunction::COMPARE_INTERPRETED, false), compare_fn);
+  replaced = codegen->FinalizeFunction(replaced);
+  if (replaced == nullptr) {
+    return Status("TupleSorter::Codegen(): codegen'd Compare() function failed "
+                  "verification, see log");
+  }
+  codegen->AddFunctionToJit(replaced, reinterpret_cast<void**>(&codegened_sort_helper_));
+  return Status::OK();
+}
+
 Sorter::Sorter(const std::vector<ScalarExpr*>& ordering_exprs,
       const std::vector<bool>& is_asc_order, const std::vector<bool>& nulls_first,
     const vector<ScalarExpr*>& sort_tuple_exprs, RowDescriptor* output_row_desc,
@@ -808,6 +832,7 @@ Status Sorter::Prepare(ObjectPool* obj_pool) {
         PrettyPrinter::Print(state_->query_options().max_row_size, TUnit::BYTES));
   }
   has_var_len_slots_ = sort_tuple_desc->HasVarlenSlots();
+  compare_less_than_.Prepare(&obj_pool_, state_, &expr_perm_pool_, &expr_results_pool_);
   in_mem_tuple_sorter_.reset(
       new TupleSorter(this, compare_less_than_, sort_tuple_desc->byte_size(), state_));
 
@@ -828,14 +853,16 @@ Status Sorter::Prepare(ObjectPool* obj_pool) {
 }
 
 Status Sorter::Codegen(RuntimeState* state) {
-  return compare_less_than_.Codegen(state);
+  llvm::Function* compare_fn;
+  LlvmCodeGen* codegen = state->codegen();
+  RETURN_IF_ERROR(compare_less_than_.CodegenCompare(codegen, &compare_fn));
+  return in_mem_tuple_sorter_->Codegen(compare_fn, state);
 }
 
 Status Sorter::Open() {
   DCHECK(in_mem_tuple_sorter_ != nullptr) << "Not prepared";
   DCHECK(unsorted_run_ == nullptr) << "Already open";
-  RETURN_IF_ERROR(compare_less_than_.Open(&obj_pool_, state_, &expr_perm_pool_,
-      &expr_results_pool_));
+  RETURN_IF_ERROR(compare_less_than_.Open(state_));
   TupleDescriptor* sort_tuple_desc = output_row_desc_->tuple_descriptors()[0];
   unsorted_run_ = run_pool_.Add(new Run(this, sort_tuple_desc, true));
   RETURN_IF_ERROR(unsorted_run_->Init());
diff --git a/be/src/util/tuple-row-compare.cc b/be/src/util/tuple-row-compare.cc
index f05a88e..2cbfef6 100644
--- a/be/src/util/tuple-row-compare.cc
+++ b/be/src/util/tuple-row-compare.cc
@@ -29,39 +29,40 @@
 using namespace impala;
 using namespace strings;
 
-Status TupleRowComparator::Open(ObjectPool* pool, RuntimeState* state,
+Status TupleRowComparator::Prepare(ObjectPool* pool, RuntimeState* state,
     MemPool* expr_perm_pool, MemPool* expr_results_pool) {
-  if (ordering_expr_evals_lhs_.empty()) {
-    RETURN_IF_ERROR(ScalarExprEvaluator::Create(ordering_exprs_, state, pool,
-        expr_perm_pool, expr_results_pool, &ordering_expr_evals_lhs_));
-    RETURN_IF_ERROR(ScalarExprEvaluator::Open(ordering_expr_evals_lhs_, state));
-  }
+  RETURN_IF_ERROR(ScalarExprEvaluator::Create(ordering_exprs_, state, pool,
+      expr_perm_pool, expr_results_pool, &ordering_expr_evals_lhs_));
+  RETURN_IF_ERROR(ScalarExprEvaluator::Create(ordering_exprs_, state, pool,
+      expr_perm_pool, expr_results_pool, &ordering_expr_evals_rhs_));
   DCHECK_EQ(ordering_exprs_.size(), ordering_expr_evals_lhs_.size());
-  if (ordering_expr_evals_rhs_.empty()) {
-    RETURN_IF_ERROR(ScalarExprEvaluator::Clone(pool, state, expr_perm_pool,
-        expr_results_pool, ordering_expr_evals_lhs_, &ordering_expr_evals_rhs_));
-  }
-  DCHECK_EQ(ordering_expr_evals_lhs_.size(), ordering_expr_evals_rhs_.size());
+  DCHECK_EQ(ordering_exprs_.size(), ordering_expr_evals_rhs_.size());
+  return Status::OK();
+}
+
+
+Status TupleRowComparator::Open(RuntimeState* state) {
+  RETURN_IF_ERROR(ScalarExprEvaluator::Open(ordering_expr_evals_lhs_, state));
+  RETURN_IF_ERROR(ScalarExprEvaluator::Open(ordering_expr_evals_rhs_, state));
   return Status::OK();
 }
 
 void TupleRowComparator::Close(RuntimeState* state) {
-  ScalarExprEvaluator::Close(ordering_expr_evals_rhs_, state);
   ScalarExprEvaluator::Close(ordering_expr_evals_lhs_, state);
+  ScalarExprEvaluator::Close(ordering_expr_evals_rhs_, state);
 }
 
 int TupleRowComparator::CompareInterpreted(
     const TupleRow* lhs, const TupleRow* rhs) const {
   DCHECK_EQ(ordering_exprs_.size(), ordering_expr_evals_lhs_.size());
-  DCHECK_EQ(ordering_expr_evals_lhs_.size(), ordering_expr_evals_rhs_.size());
-  for (int i = 0; i < ordering_expr_evals_lhs_.size(); ++i) {
+  for (int i = 0; i < ordering_exprs_.size(); ++i) {
     void* lhs_value = ordering_expr_evals_lhs_[i]->GetValue(lhs);
     void* rhs_value = ordering_expr_evals_rhs_[i]->GetValue(rhs);
 
     // The sort order of NULLs is independent of asc/desc.
-    if (lhs_value == NULL && rhs_value == NULL) continue;
-    if (lhs_value == NULL && rhs_value != NULL) return nulls_first_[i];
-    if (lhs_value != NULL && rhs_value == NULL) return -nulls_first_[i];
+    if (lhs_value == nullptr && rhs_value == nullptr) continue;
+    if (lhs_value == nullptr) return nulls_first_[i];
+    if (rhs_value == nullptr) return -nulls_first_[i];
 
     int result = RawValue::Compare(lhs_value, rhs_value, ordering_exprs_[i]->type());
     if (!is_asc_[i]) result = -result;
@@ -76,44 +77,33 @@ Status TupleRowComparator::Codegen(RuntimeState* state) {
   LlvmCodeGen* codegen = state->codegen();
   DCHECK(codegen != NULL);
   RETURN_IF_ERROR(CodegenCompare(codegen, &fn));
-  codegend_compare_fn_ = state->obj_pool()->Add(new CompareFn);
-  codegen->AddFunctionToJit(fn, reinterpret_cast<void**>(codegend_compare_fn_));
+  codegen->AddFunctionToJit(fn, reinterpret_cast<void**>(&codegend_compare_fn_));
   return Status::OK();
 }
 
-// Codegens an unrolled version of Compare(). Uses codegen'd key exprs and injects
-// nulls_first_ and is_asc_ values.
+
+// Codegens an unrolled version of CompareInterpreted(). Uses codegen'd key exprs and
+// injects nulls_first_ and is_asc_ values.
 //
 // Example IR for comparing an int column then a float column:
 //
-// ; Function Attrs: alwaysinline
-// define i32 @Compare(%"class.impala::ScalarExprEvaluator"**
-//                         %ordering_expr_evals_lhs,
-//                     %"class.impala::ScalarExprEvaluator"**
-//                         %ordering_expr_evals_rhs,
-//                     %"class.impala::TupleRow"* %lhs,
-//                     %"class.impala::TupleRow"* %rhs) #20 {
+// define i32 @Compare(%"struct.impala::TupleRowComparator"* %this,
+//     %"class.impala::TupleRow"* %lhs, %"class.impala::TupleRow"* %rhs) #40 {
 // entry:
-//   %type13 = alloca %"struct.impala::ColumnType"
 //   %0 = alloca float
 //   %1 = alloca float
-//   %type = alloca %"struct.impala::ColumnType"
 //   %2 = alloca i32
 //   %3 = alloca i32
-//   %4 = getelementptr %"class.impala::ScalarExprEvaluator"**
-//            %ordering_expr_evals_lhs, i32 0
-//   %5 = load %"class.impala::ScalarExprEvaluator"** %4
-//   %lhs_value = call i64 @GetSlotRef(
-//       %"class.impala::ScalarExprEvaluator"* %5, %"class.impala::TupleRow"* %lhs)
-//   %6 = getelementptr %"class.impala::ScalarExprEvaluator"**
-//            %ordering_expr_evals_rhs, i32 0
-//   %7 = load %"class.impala::ScalarExprEvaluator"** %6
-//   %rhs_value = call i64 @GetSlotRef(
-//       %"class.impala::ScalarExprEvaluator"* %7, %"class.impala::TupleRow"* %rhs)
+//   %lhs_value = call i64 @GetSlotRef(%"class.impala::ScalarExprEvaluator"* inttoptr
+//       (i64 165283904 to %"class.impala::ScalarExprEvaluator"*),
+//       %"class.impala::TupleRow"* %lhs)
+//   %rhs_value = call i64 @GetSlotRef(%"class.impala::ScalarExprEvaluator"* inttoptr
+//       (i64 165284288 to %"class.impala::ScalarExprEvaluator"*),
+//       %"class.impala::TupleRow"* %rhs)
 //   %is_null = trunc i64 %lhs_value to i1
 //   %is_null1 = trunc i64 %rhs_value to i1
 //   %both_null = and i1 %is_null, %is_null1
-//   br i1 %both_null, label %next_key, label %non_null
+//      br i1 %both_null, label %next_key, label %non_null
 //
 // non_null:                                         ; preds = %entry
 //   br i1 %is_null, label %lhs_null, label %lhs_non_null
@@ -128,76 +118,64 @@ Status TupleRowComparator::Codegen(RuntimeState* state) {
 //   ret i32 -1
 //
 // rhs_non_null:                                     ; preds = %lhs_non_null
-//   %8 = ashr i64 %lhs_value, 32
-//   %9 = trunc i64 %8 to i32
-//   store i32 %9, i32* %3
-//   %10 = bitcast i32* %3 to i8*
-//   %11 = ashr i64 %rhs_value, 32
-//   %12 = trunc i64 %11 to i32
-//   store i32 %12, i32* %2
-//   %13 = bitcast i32* %2 to i8*
-//   store %"struct.impala::ColumnType" { i32 5, i32 -1, i32 -1, i32 -1,
-//                                        %"class.std::vector.44" zeroinitializer,
-//                                        %"class.std::vector.49" zeroinitializer },
-//         %"struct.impala::ColumnType"* %type
+//   %4 = ashr i64 %lhs_value, 32
+//   %5 = trunc i64 %4 to i32
+//   store i32 %5, i32* %3
+//   %6 = bitcast i32* %3 to i8*
+//   %7 = ashr i64 %rhs_value, 32
+//   %8 = trunc i64 %7 to i32
+//   store i32 %8, i32* %2
+//   %9 = bitcast i32* %2 to i8*
 //   %result = call i32 @_ZN6impala8RawValue7CompareEPKvS2_RKNS_10ColumnTypeE(
-//       i8* %10, i8* %13, %"struct.impala::ColumnType"* %type)
-//   %14 = icmp ne i32 %result, 0
-//   br i1 %14, label %result_nonzero, label %next_key
+//       i8* %6, i8* %9, %"struct.impala::ColumnType"* @type)
+//   %10 = icmp ne i32 %result, 0
+//   br i1 %10, label %result_nonzero, label %next_key
 //
 // result_nonzero:                                   ; preds = %rhs_non_null
 //   ret i32 %result
 //
 // next_key:                                         ; preds = %rhs_non_null, %entry
-//   %15 = getelementptr %"class.impala::ScalarExprEvaluator"**
-//             %ordering_expr_evals_lhs, i32 1
-//   %16 = load %"class.impala::ScalarExprEvaluator"** %15
-//   %lhs_value3 = call i64 @GetSlotRef1(
-//       %"class.impala::ScalarExprEvaluator"* %16, %"class.impala::TupleRow"* %lhs)
-//   %17 = getelementptr %"class.impala::ScalarExprEvaluator"**
-//            %ordering_expr_evals_rhs, i32 1
-//   %18 = load %"class.impala::ScalarExprEvaluator"** %17
-//   %rhs_value4 = call i64 @GetSlotRef1(
-//       %"class.impala::ScalarExprEvaluator"* %18, %"class.impala::TupleRow"* %rhs)
+//   %lhs_value3 = call i64 @GetSlotRef.1(%"class.impala::ScalarExprEvaluator"* inttoptr
+//       (i64 165284096 to %"class.impala::ScalarExprEvaluator"*),
+//       %"class.impala::TupleRow"* %lhs)
+//   %rhs_value4 = call i64 @GetSlotRef.1(%"class.impala::ScalarExprEvaluator"* inttoptr
+//       (i64 165284480 to %"class.impala::ScalarExprEvaluator"*),
+//       %"class.impala::TupleRow"* %rhs)
 //   %is_null5 = trunc i64 %lhs_value3 to i1
 //   %is_null6 = trunc i64 %rhs_value4 to i1
-//   %both_null8 = and i1 %is_null5, %is_null6
-//   br i1 %both_null8, label %next_key2, label %non_null7
+//   %both_null7 = and i1 %is_null5, %is_null6
+//   br i1 %both_null7, label %next_key2, label %non_null8
 //
-// non_null7:                                        ; preds = %next_key
+// non_null8:                                        ; preds = %next_key
 //   br i1 %is_null5, label %lhs_null9, label %lhs_non_null10
 //
-// lhs_null9:                                        ; preds = %non_null7
+// lhs_null9:                                        ; preds = %non_null8
 //   ret i32 1
 //
-// lhs_non_null10:                                   ; preds = %non_null7
+// lhs_non_null10:                                   ; preds = %non_null8
 //   br i1 %is_null6, label %rhs_null11, label %rhs_non_null12
 //
 // rhs_null11:                                       ; preds = %lhs_non_null10
 //   ret i32 -1
 //
 // rhs_non_null12:                                   ; preds = %lhs_non_null10
-//   %19 = ashr i64 %lhs_value3, 32
-//   %20 = trunc i64 %19 to i32
-//   %21 = bitcast i32 %20 to float
-//   store float %21, float* %1
-//   %22 = bitcast float* %1 to i8*
-//   %23 = ashr i64 %rhs_value4, 32
-//   %24 = trunc i64 %23 to i32
-//   %25 = bitcast i32 %24 to float
-//   store float %25, float* %0
-//   %26 = bitcast float* %0 to i8*
-//   store %"struct.impala::ColumnType" { i32 7, i32 -1, i32 -1, i32 -1,
-//                                        %"class.std::vector.44" zeroinitializer,
-//                                        %"class.std::vector.49" zeroinitializer },
-//         %"struct.impala::ColumnType"* %type13
-//   %result14 = call i32 @_ZN6impala8RawValue7CompareEPKvS2_RKNS_10ColumnTypeE(
-//       i8* %22, i8* %26, %"struct.impala::ColumnType"* %type13)
-//   %27 = icmp ne i32 %result14, 0
-//   br i1 %27, label %result_nonzero15, label %next_key2
+//   %11 = ashr i64 %lhs_value3, 32
+//   %12 = trunc i64 %11 to i32
+//   %13 = bitcast i32 %12 to float
+//   store float %13, float* %1
+//   %14 = bitcast float* %1 to i8*
+//   %15 = ashr i64 %rhs_value4, 32
+//   %16 = trunc i64 %15 to i32
+//   %17 = bitcast i32 %16 to float
+//   store float %17, float* %0
+//   %18 = bitcast float* %0 to i8*
+//   %result13 = call i32 @_ZN6impala8RawValue7CompareEPKvS2_RKNS_10ColumnTypeE(
+//       i8* %14, i8* %18, %"struct.impala::ColumnType"* @type.2)
+//   %19 = icmp ne i32 %result13, 0
+//   br i1 %19, label %result_nonzero14, label %next_key2
 //
-// result_nonzero15:                                 ; preds = %rhs_non_null12
-//   ret i32 %result14
+// result_nonzero14:                                 ; preds = %rhs_non_null12
+//   ret i32 %result13
 //
 // next_key2:                                        ; preds = %rhs_non_null12, %next_key
 //   ret i32 0
@@ -214,55 +192,48 @@ Status TupleRowComparator::CodegenCompare(LlvmCodeGen* codegen, llvm::Function**
     }
   }
 
-  // Construct function signature (note that this is different than the interpreted
-  // Compare() function signature):
-  // int Compare(ScalarExprEvaluator** ordering_expr_evals_lhs,
-  //     ScalarExprEvaluator** ordering_expr_evals_rhs,
-  //     TupleRow* lhs, TupleRow* rhs)
-  llvm::PointerType* expr_evals_type =
-      codegen->GetStructPtrPtrType<ScalarExprEvaluator>();
+  // Construct function signature:
+  // int Compare(TupleRowComparator* this, TupleRow* lhs, TupleRow* rhs)
   llvm::PointerType* tuple_row_type = codegen->GetStructPtrType<TupleRow>();
   LlvmCodeGen::FnPrototype prototype(codegen, "Compare", codegen->i32_type());
-  prototype.AddArgument("ordering_expr_evals_lhs", expr_evals_type);
-  prototype.AddArgument("ordering_expr_evals_rhs", expr_evals_type);
+  // 'this' is only present to match the signature of CompareInterpreted(); it is unused.
+  prototype.AddArgument("this", codegen->GetStructPtrType<TupleRowComparator>());
   prototype.AddArgument("lhs", tuple_row_type);
   prototype.AddArgument("rhs", tuple_row_type);
 
   LlvmBuilder builder(context);
-  llvm::Value* args[4];
+  llvm::Value* args[3];
   *fn = prototype.GeneratePrototype(&builder, args);
-  llvm::Value* lhs_evals_arg = args[0];
-  llvm::Value* rhs_evals_arg = args[1];
-  llvm::Value* lhs_arg = args[2];
-  llvm::Value* rhs_arg = args[3];
+  llvm::Value* lhs_arg = args[1];
+  llvm::Value* rhs_arg = args[2];
+
+  llvm::PointerType* expr_eval_type = codegen->GetStructPtrType<ScalarExprEvaluator>();
 
   // Unrolled loop over each key expr
   for (int i = 0; i < ordering_exprs.size(); ++i) {
     // The start of the next key expr after this one. Used to implement "continue" logic
     // in the unrolled loop.
     llvm::BasicBlock* next_key_block = llvm::BasicBlock::Create(context, "next_key", *fn);
-
-    // Call key_fns[i](ordering_expr_evals_lhs[i], lhs_arg)
-    llvm::Value* lhs_eval = codegen->CodegenArrayAt(&builder, lhs_evals_arg, i);
-    llvm::Value* lhs_args[] = {lhs_eval, lhs_arg};
+    llvm::Value* eval_lhs = codegen->CastPtrToLlvmPtr(expr_eval_type,
+        ordering_expr_evals_lhs_[i]);
+    llvm::Value* eval_rhs = codegen->CastPtrToLlvmPtr(expr_eval_type,
+        ordering_expr_evals_rhs_[i]);
+    // Call key_fns[i](ordering_expr_evals_[i], lhs_arg)
     CodegenAnyVal lhs_value = CodegenAnyVal::CreateCallWrapped(codegen, &builder,
-        ordering_exprs[i]->type(), key_fns[i], lhs_args, "lhs_value");
-
-    // Call key_fns[i](ordering_expr_evals_rhs[i], rhs_arg)
-    llvm::Value* rhs_eval = codegen->CodegenArrayAt(&builder, rhs_evals_arg, i);
-    llvm::Value* rhs_args[] = {rhs_eval, rhs_arg};
+        ordering_exprs[i]->type(), key_fns[i], {eval_lhs, lhs_arg}, "lhs_value");
+    // Call key_fns[i](ordering_expr_evals_[i], rhs_arg)
     CodegenAnyVal rhs_value = CodegenAnyVal::CreateCallWrapped(codegen, &builder,
-        ordering_exprs[i]->type(), key_fns[i], rhs_args, "rhs_value");
+        ordering_exprs[i]->type(), key_fns[i], {eval_rhs, rhs_arg}, "rhs_value");
 
-    // Handle NULLs if necessary
+    // Handle NULL values if necessary
     llvm::Value* lhs_null = lhs_value.GetIsNull();
     llvm::Value* rhs_null = rhs_value.GetIsNull();
-    // if (lhs_value == NULL && rhs_value == NULL) continue;
+    // if (lhs_value is NULL && rhs_value is NULL) continue;
     llvm::Value* both_null = builder.CreateAnd(lhs_null, rhs_null, "both_null");
     llvm::BasicBlock* non_null_block =
         llvm::BasicBlock::Create(context, "non_null", *fn, next_key_block);
     builder.CreateCondBr(both_null, next_key_block, non_null_block);
-    // if (lhs_value == NULL && rhs_value != NULL) return nulls_first_[i];
+    // if (lhs_value is NULL && rhs_value is not NULL) return nulls_first_[i];
     builder.SetInsertPoint(non_null_block);
     llvm::BasicBlock* lhs_null_block =
         llvm::BasicBlock::Create(context, "lhs_null", *fn, next_key_block);
@@ -271,7 +242,7 @@ Status TupleRowComparator::CodegenCompare(LlvmCodeGen* codegen, llvm::Function**
     builder.CreateCondBr(lhs_null, lhs_null_block, lhs_non_null_block);
     builder.SetInsertPoint(lhs_null_block);
     builder.CreateRet(builder.getInt32(nulls_first_[i]));
-    // if (lhs_value != NULL && rhs_value == NULL) return -nulls_first_[i];
+    // if (lhs_value is not NULL && rhs_value is NULL) return -nulls_first_[i];
     builder.SetInsertPoint(lhs_non_null_block);
     llvm::BasicBlock* rhs_null_block =
         llvm::BasicBlock::Create(context, "rhs_null", *fn, next_key_block);
@@ -301,9 +272,11 @@ Status TupleRowComparator::CodegenCompare(LlvmCodeGen* codegen, llvm::Function**
   }
   builder.CreateRet(builder.getInt32(0));
   *fn = codegen->FinalizeFunction(*fn);
-  if (*fn == NULL) {
+  if (*fn == nullptr) {
     return Status("Codegen'd TupleRowComparator::Compare() function failed verification, "
         "see log");
   }
   return Status::OK();
 }
+
+const char* TupleRowComparator::LLVM_CLASS_NAME = "struct.impala::TupleRowComparator";
diff --git a/be/src/util/tuple-row-compare.h b/be/src/util/tuple-row-compare.h
index e1933d5..2d9c758 100644
--- a/be/src/util/tuple-row-compare.h
+++ b/be/src/util/tuple-row-compare.h
@@ -58,8 +58,7 @@ class ComparatorWrapper {
 };
 
 /// Compares two TupleRows based on a set of exprs, in order.
-class TupleRowComparator {
- public:
+struct TupleRowComparator {
   /// 'ordering_exprs': the ordering expressions for tuple comparison.
   /// 'is_asc' determines, for each expr, if it should be ascending or descending sort
   /// order.
@@ -67,61 +66,57 @@ class TupleRowComparator {
   /// other values.
   TupleRowComparator(const std::vector<ScalarExpr*>& ordering_exprs,
       const std::vector<bool>& is_asc, const std::vector<bool>& nulls_first)
-    : ordering_exprs_(ordering_exprs),
-      is_asc_(is_asc),
-      codegend_compare_fn_(nullptr) {
+    : ordering_exprs_(ordering_exprs), is_asc_(is_asc), codegend_compare_fn_(nullptr) {
     DCHECK_EQ(is_asc_.size(), ordering_exprs.size());
     for (bool null_first : nulls_first) nulls_first_.push_back(null_first ? -1 : 1);
   }
 
   /// Create the evaluators for the ordering expressions and store them in 'pool'. The
   /// evaluators use 'expr_perm_pool' and 'expr_results_pool' for permanent and result
-  /// allocations made by exprs respectively. 'state' is passed in for initialization
-  /// of the evaluator.
-  Status Open(ObjectPool* pool, RuntimeState* state, MemPool* expr_perm_pool,
+  /// allocations made by exprs respectively.
+  Status Prepare(ObjectPool* pool, RuntimeState* state, MemPool* expr_perm_pool,
       MemPool* expr_results_pool);
 
+  /// Initialize the evaluators using 'state'.
+  Status Open(RuntimeState* state);
+
   /// Release resources held by the ordering expressions' evaluators.
   void Close(RuntimeState* state);
 
-  /// Codegens a Compare() function for this comparator that is used in Compare().
-  Status Codegen(RuntimeState* state);
-
-  /// Returns a negative value if lhs is less than rhs, a positive value if lhs is
-  /// greater than rhs, or 0 if they are equal. All exprs (ordering_exprs_lhs_ and
-  /// ordering_exprs_rhs_) must have been prepared and opened before calling this,
-  /// i.e. 'sort_key_exprs' in the constructor must have been opened.
-  int ALWAYS_INLINE Compare(const TupleRow* lhs, const TupleRow* rhs) const {
-    return codegend_compare_fn_ == NULL ?
-        CompareInterpreted(lhs, rhs) :
-        (*codegend_compare_fn_)(ordering_expr_evals_lhs_.data(),
-            ordering_expr_evals_rhs_.data(), lhs, rhs);
-  }
-
   /// Returns true if lhs is strictly less than rhs.
-  /// All exprs (ordering_exprs_lhs_ and ordering_exprs_rhs_) must have been prepared
-  /// and opened before calling this.
+  /// ordering_exprs must have been prepared and opened before calling this.
   /// Force inlining because it tends not to be always inlined at callsites, even in
   /// hot loops.
   bool ALWAYS_INLINE Less(const TupleRow* lhs, const TupleRow* rhs) const {
-    return Compare(lhs, rhs) < 0;
+    return CompareInterpreted(lhs, rhs) < 0;
   }
 
   bool ALWAYS_INLINE Less(const Tuple* lhs, const Tuple* rhs) const {
-    TupleRow* lhs_row = reinterpret_cast<TupleRow*>(&lhs);
-    TupleRow* rhs_row = reinterpret_cast<TupleRow*>(&rhs);
-    return Less(lhs_row, rhs_row);
+    return Less(reinterpret_cast<TupleRow*>(&lhs), reinterpret_cast<TupleRow*>(&rhs));
   }
 
- private:
-  /// Interpreted implementation of Compare().
-  int CompareInterpreted(const TupleRow* lhs, const TupleRow* rhs) const;
-
-  /// Codegen Compare(). Returns a non-OK status if codegen is unsuccessful.
-  /// TODO: inline this at codegen'd callsites instead of indirectly calling via function
-  /// pointer.
+  /// Codegen Compare() and return it via 'fn'. The caller could use it to replace
+  /// cross-compiled CompareInterpreted().
   Status CodegenCompare(LlvmCodeGen* codegen, llvm::Function** fn);
 
+  /// Codegen Compare() and compile it into 'codegend_compare_fn_'. The resulting function
+  /// pointer can be used by MaybeCodegenedLess() in circumstances where Compare() hasn't
+  /// been inlined.
+  Status Codegen(RuntimeState* state);
+
+  /// Use codegened compare function if available. Cross-compiled callers should call
+  /// Less() directly and have CompareInterpreted() replaced with a codegened Compare().
+  bool ALWAYS_INLINE MaybeCodegenedLess(const TupleRow* lhs, const TupleRow* rhs) const{
+      return codegend_compare_fn_ == nullptr ? Less(lhs, rhs) :
+          codegend_compare_fn_(this, lhs, rhs) < 0;
+  }
+
+  /// Struct name in LLVM IR.
+  static const char* LLVM_CLASS_NAME;
+
+  /// Interpreted implementation of Compare().
+  int IR_NO_INLINE CompareInterpreted(const TupleRow* lhs, const TupleRow* rhs) const;
+
   /// References to ordering expressions owned by the Exec node which owns this
   /// TupleRowComparator.
   const std::vector<ScalarExpr*>& ordering_exprs_;
@@ -134,16 +129,10 @@ class TupleRowComparator {
   const std::vector<bool>& is_asc_;
   std::vector<int8_t> nulls_first_;
 
-  /// We store a pointer to the codegen'd function pointer (adding an extra level of
-  /// indirection) so that copies of this TupleRowComparator will have the same pointer to
-  /// the codegen'd function. This is necessary because the codegen'd function pointer is
-  /// only set after the IR module is compiled. Without the indirection, if this
-  /// TupleRowComparator is copied before the module is compiled, the copy will still have
-  /// its function pointer set to NULL. The function pointer is allocated from the runtime
-  /// state's object pool so that its lifetime will be >= that of any copies.
-  typedef int (*CompareFn)(ScalarExprEvaluator* const*, ScalarExprEvaluator* const*,
-      const TupleRow*, const TupleRow*);
-  CompareFn* codegend_compare_fn_;
+  using CompareFn = int (*) (const TupleRowComparator*, const TupleRow*, const TupleRow*);
+  CompareFn codegend_compare_fn_;
+
+  DISALLOW_COPY_AND_ASSIGN(TupleRowComparator);
 };
 
 /// Compares the equality of two Tuples, going slot by slot.
diff --git a/tests/custom_cluster/test_replace_tuple_row_compare.py b/tests/custom_cluster/test_replace_tuple_row_compare.py
new file mode 100644
index 0000000..71ab47a
--- /dev/null
+++ b/tests/custom_cluster/test_replace_tuple_row_compare.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import re
+
+from tests.common.custom_cluster_test_suite import CustomClusterTestSuite
+
+
+class TestReplaceTupleRowCompare(CustomClusterTestSuite):
+  """ Tests that the cross-compiled function CompareInterpreted() is not present in the
+  optimized IR after it's replaced by the codegened Compare(). """
+  log_dir = os.getenv('LOG_DIR', "/tmp/")
+
+  @classmethod
+  def get_workload(cls):
+    return 'functional-query'
+
+  @CustomClusterTestSuite.with_args(
+      "-dump_ir -opt_module_dir=\"" + log_dir + "\"")
+  def test_replace_tuple_row_compare(self):
+    self.client.execute("SET DISABLE_CODEGEN_ROWS_THRESHOLD=0")
+    # Sort and topn multiple times to test that the replacement is COW.
+    sort_query = "select row_number() over (order by int_col, float_col), row_number() " \
+                 "over (order by float_col, int_col) from functional.alltypes"
+    topn_query = "select * from (select * from functional.alltypes order by int_col " \
+                 "limit 10) v union (select * from functional.alltypes order by " \
+                 "float_col limit 10) order by int_col, float_col limit 5"
+    for query in [sort_query, topn_query]:
+      result = self.client.execute(query)
+      assert result.success
+      coord_id = re.search("Query \(id=([^)]*)\)", result.runtime_profile).group(1)
+      ir = open(os.path.join(self.log_dir, coord_id[:-1] + "1_opt.ll")).read()
+      # If CompareInterpreted is inlined, there should still be labels named like
+      # "_ZNK6impala18TupleRowComparator18CompareInterpretedEPKNS_8TupleRowES3_.exit"
+      # in the IR. So the absence of CompareInterpreted in the text shows that it's
+      # replaced.
+      assert "Compare" in ir and "CompareInterpreted" not in ir
-- 
2.7.4

