diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java
index c570356d8b..9bf42ed384 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java
@@ -61,6 +61,7 @@
 import org.apache.calcite.util.mapping.Mapping;
 import org.apache.calcite.util.mapping.MappingType;
 import org.apache.calcite.util.mapping.Mappings;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
 import org.apache.hadoop.hive.ql.metadata.Table;
 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
 import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
@@ -672,10 +673,10 @@ public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Se
   public TrimResult trimFields(Project project, ImmutableBitSet fieldsUsed,
       Set<RelDataTypeField> extraFields) {
     // set columnAccessInfo for ViewColumnAuthorization
-    for (Ord<RexNode> ord : Ord.zip(project.getProjects())) {
-      if (fieldsUsed.get(ord.i)) {
-        if (this.columnAccessInfo != null && this.viewProjectToTableSchema != null
-            && this.viewProjectToTableSchema.containsKey(project)) {
+    if (this.columnAccessInfo != null && this.viewProjectToTableSchema != null
+        && this.viewProjectToTableSchema.containsKey(project)) {
+      for (Ord<RexNode> ord : Ord.zip(project.getProjects())) {
+        if (fieldsUsed.get(ord.i)) {
           Table tab = this.viewProjectToTableSchema.get(project);
           this.columnAccessInfo.add(tab.getCompleteName(), tab.getAllCols().get(ord.i).getName());
         }
@@ -684,10 +685,26 @@ public TrimResult trimFields(Project project, ImmutableBitSet fieldsUsed,
     return super.trimFields(project, fieldsUsed, extraFields);
   }
 
-  @Override
-  public TrimResult trimFields(TableScan tableAccessRel, ImmutableBitSet fieldsUsed,
+  public TrimResult trimFields(HiveTableScan tableAccessRel, ImmutableBitSet fieldsUsed,
       Set<RelDataTypeField> extraFields) {
     final TrimResult result = super.trimFields(tableAccessRel, fieldsUsed, extraFields);
+    if (this.columnAccessInfo != null) {
+      // Store information about column accessed by the table so it can be used
+      // to send only this information for column masking
+      final RelOptHiveTable tab = (RelOptHiveTable) tableAccessRel.getTable();
+      final String qualifiedName = tab.getHiveTableMD().getCompleteName();
+      final List<FieldSchema> allCols = tab.getHiveTableMD().getAllCols();
+      final boolean insideView = tableAccessRel.isInsideView();
+      fieldsUsed.asList().stream()
+        .filter(idx -> idx < tab.getNoOfNonVirtualCols())
+        .forEach(idx -> {
+          if (insideView) {
+            columnAccessInfo.addIndirect(qualifiedName, allCols.get(idx).getName());
+          } else {
+            columnAccessInfo.add(qualifiedName, allCols.get(idx).getName());
+          }
+        });
+    }
     if (fetchStats) {
       fetchColStats(result.getKey(), tableAccessRel, fieldsUsed, extraFields);
     }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
index 4762335a0f..2b9caac03f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java
@@ -307,6 +307,7 @@
   private SemanticException semanticException;
   private boolean runCBO = true;
   private boolean disableSemJoinReordering = true;
+
   private EnumSet<ExtendedCBOProfile> profilesCBO;
 
   private static final CommonToken FROM_TOKEN =
@@ -1780,8 +1781,8 @@ public RelNode apply(RelOptCluster cluster, RelOptSchema relOptSchema, SchemaPlu
 
       // We need to get the ColumnAccessInfo and viewToTableSchema for views.
       HiveRelFieldTrimmer fieldTrimmer = new HiveRelFieldTrimmer(null,
-          HiveRelFactories.HIVE_BUILDER.create(optCluster, null), this.columnAccessInfo,
-          this.viewProjectToTableSchema);
+          HiveRelFactories.HIVE_BUILDER.create(optCluster, null),
+          this.columnAccessInfo, this.viewProjectToTableSchema);
 
       fieldTrimmer.trim(calciteGenPlan);
 
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/ColumnAccessInfo.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/ColumnAccessInfo.java
index 9fb6a4e5e2..df2ee6afbf 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/ColumnAccessInfo.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/ColumnAccessInfo.java
@@ -20,42 +20,73 @@
 
 import org.apache.hadoop.hive.ql.metadata.VirtualColumn;
 
-import java.util.ArrayList;
-import java.util.Collections;
+import com.google.common.collect.LinkedHashMultimap;
+import com.google.common.collect.SetMultimap;
+
+import java.util.Collection;
 import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
+import java.util.Objects;
+import java.util.stream.Collectors;
 
 public class ColumnAccessInfo {
   /**
-   * Map of table name to names of accessed columns
+   * Map of table name to names of accessed columns (directly and indirectly -through views-).
    */
-  private final Map<String, Set<String>> tableToColumnAccessMap;
+  private final SetMultimap<String, ColumnAccess> tableToColumnAccessMap;
 
   public ColumnAccessInfo() {
     // Must be deterministic order map for consistent q-test output across Java versions
-    tableToColumnAccessMap = new LinkedHashMap<String, Set<String>>();
+    tableToColumnAccessMap = LinkedHashMultimap.create();
   }
 
+  /**
+   * Records a direct access to column {@code col} of table {@code table}.
+   */
   public void add(String table, String col) {
-    Set<String> tableColumns = tableToColumnAccessMap.get(table);
-    if (tableColumns == null) {
-      // Must be deterministic order set for consistent q-test output across Java versions
-      tableColumns = new LinkedHashSet<String>();
-      tableToColumnAccessMap.put(table, tableColumns);
-    }
-    tableColumns.add(col);
+    tableToColumnAccessMap.put(table, new ColumnAccess(col, Access.DIRECT));
+  }
+
+  /**
+   * Records an indirect access (through a view) to column {@code col} of table {@code table}.
+   */
+  public void addIndirect(String table, String col) {
+    tableToColumnAccessMap.put(table, new ColumnAccess(col, Access.INDIRECT));
   }
 
+  /**
+   * Returns, per table, the sorted names of the columns that were accessed directly.
+   */
   public Map<String, List<String>> getTableToColumnAccessMap() {
     // Must be deterministic order map for consistent q-test output across Java versions
     Map<String, List<String>> mapping = new LinkedHashMap<String, List<String>>();
-    for (Map.Entry<String, Set<String>> entry : tableToColumnAccessMap.entrySet()) {
-      List<String> sortedCols = new ArrayList<String>(entry.getValue());
-      Collections.sort(sortedCols);
-      mapping.put(entry.getKey(), sortedCols);
+    for (Map.Entry<String, Collection<ColumnAccess>> entry : tableToColumnAccessMap.asMap().entrySet()) {
+      mapping.put(
+          entry.getKey(),
+          entry.getValue().stream()
+              .filter(ca -> ca.access == Access.DIRECT)
+              .map(ca -> ca.columnName)
+              .sorted()
+              .collect(Collectors.toList()));
+    }
+    return mapping;
+  }
+
+  /**
+   * Returns, per table, the sorted distinct names of all accessed columns (direct or indirect).
+   */
+  public Map<String, List<String>> getTableToColumnAllAccessMap() {
+    // Must be deterministic order map for consistent q-test output across Java versions
+    Map<String, List<String>> mapping = new LinkedHashMap<String, List<String>>();
+    for (Map.Entry<String, Collection<ColumnAccess>> entry : tableToColumnAccessMap.asMap().entrySet()) {
+      mapping.put(
+          entry.getKey(),
+          entry.getValue().stream()
+              .map(ca -> ca.columnName)
+              .distinct()
+              .sorted()
+              .collect(Collectors.toList()));
     }
     return mapping;
   }
@@ -66,14 +97,47 @@ public void add(String table, String col) {
    * @param vc
    */
   public void stripVirtualColumn(VirtualColumn vc) {
-    for (Map.Entry<String, Set<String>> e : tableToColumnAccessMap.entrySet()) {
-      for (String columnName : e.getValue()) {
-        if (vc.getName().equalsIgnoreCase(columnName)) {
-          e.getValue().remove(columnName);
-          break;
-        }
-      }
-    }
+    // Drop every matching access (a column may be recorded both as DIRECT and INDIRECT).
+    // Removal goes through the multimap's own values() view instead of mutating the
+    // asMap() value collections while the asMap() entries are being iterated.
+    tableToColumnAccessMap.values()
+        .removeIf(columnAccess -> vc.getName().equalsIgnoreCase(columnAccess.columnName));
+  }
 
+  /**
+   * Name of an accessed column together with the kind of access; compared by both fields.
+   */
+  private static class ColumnAccess {
+    private final String columnName;
+    private final Access access;
+
+    private ColumnAccess(String columnName, Access access) {
+      this.columnName = Objects.requireNonNull(columnName);
+      this.access = Objects.requireNonNull(access);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o instanceof ColumnAccess) {
+        ColumnAccess other = (ColumnAccess) o;
+        return columnName.equals(other.columnName)
+            && access == other.access;
+      }
+      return false;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(columnName, access);
+    }
+
+  }
+
+  private enum Access {
+    DIRECT,
+    INDIRECT
   }
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
index c6c2219968..4eecd8d27a 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
@@ -21,8 +21,6 @@
 import static java.util.Objects.nonNull;
 import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVESTATSDBCLASS;
 
-import com.google.common.collect.ArrayListMultimap;
-import com.google.common.collect.Multimap;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.security.AccessControlException;
@@ -52,7 +50,6 @@
 import java.util.regex.PatternSyntaxException;
 import java.util.stream.Collectors;
 
-import com.google.common.collect.Lists;
 import org.antlr.runtime.ClassicToken;
 import org.antlr.runtime.CommonToken;
 import org.antlr.runtime.Token;
@@ -290,7 +287,10 @@
 
 import com.google.common.base.Splitter;
 import com.google.common.base.Strings;
+import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multimap;
 import com.google.common.collect.Sets;
 import com.google.common.math.IntMath;
 import com.google.common.math.LongMath;
@@ -12146,9 +12146,23 @@ private void walkASTMarkTABREF(TableMask tableMask, ASTNode ast, Set<String> cte
             basicInfos.put(new HivePrivilegeObject(table.getDbName(), table.getTableName(), colNames), null);
           }
         } else {
-          List<String> colNames = new ArrayList<>();
-          List<String> colTypes = new ArrayList<>();
-          extractColumnInfos(table, colNames, colTypes);
+          List<String> colNames;
+          List<String> colTypes;
+          // Look the table up once: getTableToColumnAllAccessMap() rebuilds and re-sorts
+          // the whole mapping on every call, and it never stores a null list.
+          List<String> accessedCols = isCBOExecuted() && this.columnAccessInfo != null
+              ? this.columnAccessInfo.getTableToColumnAllAccessMap().get(table.getCompleteName())
+              : null;
+          if (accessedCols != null) {
+            colNames = accessedCols;
+            Map<String, String> colNameToType = table.getAllCols().stream()
+                .collect(Collectors.toMap(FieldSchema::getName, FieldSchema::getType));
+            colTypes = colNames.stream().map(colNameToType::get).collect(Collectors.toList());
+          } else {
+            colNames = new ArrayList<>();
+            colTypes = new ArrayList<>();
+            extractColumnInfos(table, colNames, colTypes);
+          }
 
           basicInfos.put(new HivePrivilegeObject(table.getDbName(), table.getTableName(), colNames),
               new MaskAndFilterInfo(colTypes, additionalTabInfo.toString(), alias, astNode, table.isView(), table.isNonNative()));