Superset 如何拼接SQL 并执行的(上)
本章,我们我们将以Superset创建图表的相关接口为脉络。讲解Superset如何通过入参生成相关Dashboard并返回Dataframe。
首先,在Superset之中,存在着两种类型的图标,对应着两种类型的数据查询接口。
http://localhost:9000/superset/explore_json/
以及
http://localhost:9000/api/v1/chart/data
两者分别对应的不同的前端类型,往往Echarts类型的chart,在获取数据的时候都是走的/data接口。是一个统一的处理逻辑。而/explore_json则是一个框架代码,父类中封装多个函数,并由一个上帝类将其组装起来并调用。而不同的chart则是存在在这一个框架中的不同类。不同的类可以选择性的替换父类中的逻辑,从而实现自己的查询逻辑。
而我们也将以/data接口作为入口,先查看Echarts类型相关的chart如何进行sql拼接查询的过程。再去看看/explore_json
http://localhost:9000/api/v1/chart/data的接口位于superset/charts/data/api.py文件中。
而具体到这个接口之中,则是将传入的参数,首先将Json入参转换为了一个query_context
query_context = self._create_query_context_from_form(json_body)
command = ChartDataCommand(query_context)
command.validate()
并封装到ChartDataCommand之中。并在最后调用了该类的_get_data_response函数进行查询
return self._get_data_response(
command, form_data=form_data, datasource=query_context.datasource
)
对于转换的步骤
主要就是利用ChartDataQueryContextSchema().load(form_data)
而在Schema之中,则是利用了Schema对应工厂类进行的转换。
superset.common.query_context_factory.QueryContextFactory下的create函数
主要就是格式化入参,查看是否是创建好的slice对象。准备数据源对象。
然后封装为了command对象,并对是否有权限访问这个数据源进行了校验。
而在get_data_response方法中,则是直接调用了command.run,获取了result进行返回。
在command.run之中,主要的查询逻辑存在于了
payload = self._query_context.get_payload(
cache_query_context=cache_query_context, force_cached=force_cached
)
也就是反过来调用上面转换得到的query_context中的get_payload中方法。
在get_payload方法之中
def get_payload(
self, cache_query_context: Optional[bool] = False, force_cached: bool = False, ) -> Dict[str, Any]: “””Returns the query results with both metadata and data””” |
则是调用processor处理器中的get_payload
这又是一个障眼法,processor会在初始化的时候传入query_context,相当于主动权交给了processor。
而在get_payload之中,则是对于每个query_context中queries数组下的每一个query进行了查询。
query_results = [
get_query_results( query_obj.result_type or self._query_context.result_type, self._query_context, query_obj, force_cached, ) for query_obj in self._query_context.queries ] |
而在get_query_results函数之中,则是根据query中的result_type调用了不同函数,比如sample进行采样操作,query则只是拼接sql,result则是拼接sql并进行查询。
这里我们以result对应的函数 get_result为线索继续往下看。_get_result 调用了_get_full函数,而在get_full函数中
def _get_full(
query_context: QueryContext, query_obj: QueryObject, force_cached: Optional[bool] = False, ) -> Dict[str, Any]: datasource = _get_datasource(query_context, query_obj) result_type = query_obj.result_type or query_context.result_type payload = query_context.get_df_payload(query_obj, force_cached=force_cached) applied_template_filters = payload.get(“applied_template_filters”, []) df = payload[“df”] status = payload[“status”] if status != QueryStatus.FAILED: payload[“colnames”] = list(df.columns) payload[“indexnames”] = list(df.index) payload[“coltypes”] = extract_dataframe_dtypes(df, datasource) payload[“data”] = query_context.get_data(df) payload[“result_format”] = query_context.result_format del payload[“df”] filters = query_obj.filter filter_columns = cast(List[str], [flt.get(“col”) for flt in filters]) columns = set(datasource.column_names) applied_time_columns, rejected_time_columns = get_time_filter_status( datasource, query_obj.applied_time_extras ) payload[“applied_filters”] = [ {“column”: get_column_name(col)} for col in filter_columns if is_adhoc_column(col) or col in columns or col in applied_template_filters ] + applied_time_columns payload[“rejected_filters”] = [ {“reason”: ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, “column”: col} for col in filter_columns if not is_adhoc_column(col) and col not in columns and col not in applied_template_filters ] + rejected_time_columns if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED: return { “data”: payload.get(“data”), “colnames”: payload.get(“colnames”), “coltypes”: payload.get(“coltypes”), } return payload |
在上面这段代码之中,首先是利用query_context中获取到了dataframe.
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
然后是进行后缀处理。
比如获取字段名字,索引名字,字段类型,以及利用get_data函数将dataframe转换为不同的格式。
if status != QueryStatus.FAILED:
payload[“colnames”] = list(df.columns) payload[“indexnames”] = list(df.index) payload[“coltypes”] = extract_dataframe_dtypes(df, datasource) payload[“data”] = query_context.get_data(df) payload[“result_format”] = query_context.result_format del payload[“df”] |
最后则是一些filter拼接后返回。
那么就核心来看,还是存在于get_df_payload函数之中。
def get_df_payload(
self, query_obj: QueryObject, force_cached: Optional[bool] = False ) -> Dict[str, Any]: “””Handles caching around the df payload retrieval””” cache = QueryCacheManager.get( cache_key, CacheRegion.DATA, self._query_context.force, force_cached, ) if query_obj and cache_key and not cache.is_loaded: try: invalid_columns = [ col for col in get_column_names_from_columns(query_obj.columns) + get_column_names_from_metrics(query_obj.metrics or []) if ( col not in self._qc_datasource.column_names and col != DTTM_ALIAS ) ] if invalid_columns: raise QueryObjectValidationError( _( “Columns missing in dataset: %(invalid_columns)s”, invalid_columns=invalid_columns, ) ) query_result = self.get_query_result(query_obj) annotation_data = self.get_annotation_data(query_obj) cache.set_query_result( key=cache_key, query_result=query_result, annotation_data=annotation_data, force_query=self._query_context.force, timeout=self.get_cache_timeout(), datasource_uid=self._qc_datasource.uid, region=CacheRegion.DATA, ) except QueryObjectValidationError as ex: cache.error_message = str(ex) cache.status = QueryStatus.FAILED # the N-dimensional DataFrame has converteds into flat DataFrame # by `flatten operator`, “comma” in the column is escaped by `escape_separator` # the result DataFrame columns should be unescaped label_map = { unescape_separator(col): [ unescape_separator(col) for col in re.split(r”(?<!\\),\s”, col) ] for col in cache.df.columns.values } cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values] return { “cache_key”: cache_key, “cached_dttm”: cache.cache_dttm, “cache_timeout”: self.get_cache_timeout(), “df”: cache.df, “applied_template_filters”: cache.applied_template_filters, “annotation_data”: cache.annotation_data, “error”: cache.error_message, “is_cached”: cache.is_cached, “query”: cache.query, “status”: cache.status, “stacktrace”: cache.stacktrace, “rowcount”: len(cache.df.index), “from_dttm”: query_obj.from_dttm, “to_dttm”: query_obj.to_dttm, “label_map”: label_map, } |
这个函数中,会首先尝试查询缓存中是否存在
如果获取不到,在调用query_result = self.get_query_result(query_obj)进行查询
查询完成之后存入缓存之中。
而从self.get_query_result函数进去。
核心的代码就几行,
query = “”
if isinstance(query_context.datasource, Query): # todo(hugh): add logic to manage all sip68 models here else: result = query_context.datasource.query(query_object.to_dict()) query = result.query + “;\n\n” df = result.df |
主要逻辑还是利用datasource进行的查询。
而在datasource的相关代码中,则是先进行了sql拼接
query_str_ext = self.get_query_str_extended(query_obj)
在之后执行sql获取到了dataframe。
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
而在拼接sql的时候,则是先拼接为SqlaQuery,再利用对应的db_engine进行转换。
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, labels_expected=sqlaq.labels_expected, prequeries=sqlaq.prequeries, sql=sql, ) |
而主要的生成SqlaQuery的函数,就是相关的核心代码了。
这里我们进行拆解。
在函数之中,首先是初始化相关参数
extras = extras or {}
time_grain = extras.get(“time_grain_sqla”) template_kwargs = { “columns”: columns, “from_dttm”: from_dttm.isoformat() if from_dttm else None, “groupby”: groupby, “metrics”: metrics, “row_limit”: row_limit, “row_offset”: row_offset, “time_column”: granularity, “time_grain”: time_grain, “to_dttm”: to_dttm.isoformat() if to_dttm else None, “table_columns”: [col.column_name for col in self.columns], “filter”: filter, } columns = columns or [] groupby = groupby or [] series_column_names = utils.get_column_names(series_columns or []) # deprecated, to be removed in 2.0 if is_timeseries and timeseries_limit: series_limit = timeseries_limit series_limit_metric = series_limit_metric or timeseries_limit_metric template_kwargs.update(self.template_params_dict) extra_cache_keys: List[Any] = [] template_kwargs[“extra_cache_keys”] = extra_cache_keys removed_filters: List[str] = [] applied_template_filters: List[str] = [] template_kwargs[“removed_filters”] = removed_filters template_kwargs[“applied_filters”] = applied_template_filters template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.db_engine_spec prequeries: List[str] = [] orderby = orderby or [] need_groupby = bool(metrics is not None or groupby) metrics = metrics or [] # For backward compatibility if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col columns_by_name: Dict[str, TableColumn] = { col.column_name: col for col in self.columns } metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} |
这里的重点是获取到了当前查询使用到的Dataset中的column和metric。
也就是最后两行。
然后是判断是否存在要求是时间类型的图但是没有时间字段,没有metric并且没有columns和groupby。
if not granularity and is_timeseries:
raise QueryObjectValidationError( _( “Datetime column not provided as part table configuration ” “and is required by this type of chart” ) ) if not metrics and not columns and not groupby: raise QueryObjectValidationError(_(“Empty query?”)) |
然后是遍历传入的metric进行相关处理
for metric in metrics:
if utils.is_adhoc_metric(metric): assert isinstance(metric, dict) metrics_exprs.append( self.adhoc_metric_to_sqla( metric=metric, columns_by_name=columns_by_name, template_processor=template_processor, ) ) elif isinstance(metric, str) and metric in metrics_by_name: metrics_exprs.append( metrics_by_name[metric].get_sqla_col( template_processor=template_processor ) ) else: raise QueryObjectValidationError( _(“Metric ‘%(metric)s’ does not exist”, metric=metric) ) |
其中主要是adhoc_metric_to_sqla,由于metric支持传入dataset中本身的metric对象和custom sql。这里进行了判断,并进行了不同的处理
其中从传入的columns_by_name 字典中获取到字段对象
然后最重要的是利用self.sqla_aggregation字段获取获取到对应的聚合函数,处理对应字段,最后利用对应的db_engine进行编译即可。
def adhoc_metric_to_sqla(
self, metric: AdhocMetric, columns_by_name: Dict[str, TableColumn], template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: “”” :param dict metric: Adhoc metric definition label = utils.get_metric_name(metric) if expression_type == utils.AdhocMetricExpressionType.SIMPLE: metric_column = metric.get(“column”) or {} column_name = cast(str, metric_column.get(“column_name”)) table_column: Optional[TableColumn] = columns_by_name.get(column_name) if table_column: sqla_column = table_column.get_sqla_col( template_processor=template_processor ) else: sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric[“aggregate”]](sqla_column) elif expression_type == utils.AdhocMetricExpressionType.SQL: expression = _process_sql_expression( expression=metric[“sqlExpression”], database_id=self.database_id, schema=self.schema, template_processor=template_processor, ) sqla_metric = literal_column(expression) else: raise QueryObjectValidationError(“Adhoc metric expressionType is invalid”) return self.make_sqla_column_compatible(sqla_metric, label) |
这里我们直接给出对应的sqla_aggregations字段
sqla_aggregations = {
“COUNT_DISTINCT”: lambda column_name: sa.func.COUNT(sa.distinct(column_name)), “COUNT”: sa.func.COUNT, “SUM”: sa.func.SUM, “AVG”: sa.func.AVG, “MIN”: sa.func.MIN, “MAX”: sa.func.MAX, } |
其中是直接利用了sqlachemy相关的func实现的。
如果是custom sql 则是直接parse sql之后封装为literal_column(expression)
这里循环处理完成metric之后,接下来进行orderby的处理
遍历orderby字段过程中,主要逻辑就是根据是否是聚合函数,来判断是否需要group by
在处理完成之后,获取到了need_group 字段值之后,根据是否是需要聚合,来使用groups 字段或者columns 字段来拼接select_exprs数组以及groupby_all_columns字典。
然后根据是否是时序图
来添加时间字段到select_exprs数组之中。
if is_timeseries:
timestamp = dttm_col.get_timestamp_expression( time_grain=time_grain, template_processor=template_processor ) # always put timestamp as the first column select_exprs.insert(0, timestamp) groupby_all_columns[timestamp.name] = timestamp |
以及根据是否具有时间字段,来在time_filters添加过滤器。
time_filters.append(
dttm_col.get_time_filter( start_dttm=from_dttm, end_dttm=to_dttm, template_processor=template_processor, ) ) |
在接下来去重所选字段之后,进行实际sqlaquery的拼接。
qry = sa.select(select_exprs)
首先是根据选择字段创建qry对象。
填充group by 字段
if groupby_all_columns:
qry = qry.group_by(*groupby_all_columns.values()) |
再其次是填充where和having字段
对于where字段,则是从入参中的filters数组进行取出拼接
这个数组每一个下标对应着一个字典,包含了col val op三个key
对于每一个数组,都是获取到字段sql,根据op,进行相关的sql拼接
assert isinstance(eq, (tuple, list))
if len(eq) == 0: raise QueryObjectValidationError( _(“Filter value list cannot be empty”) ) if len(eq) > len( eq_without_none := [x for x in eq if x is not None] ): is_null_cond = sqla_col.is_(None) if eq: cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) else: cond = is_null_cond else: cond = sqla_col.in_(eq) if op == utils.FilterOperator.NOT_IN.value: cond = ~cond where_clause_and.append(cond) elif op == utils.FilterOperator.IS_NULL.value: where_clause_and.append(sqla_col.is_(None)) elif op == utils.FilterOperator.IS_NOT_NULL.value: where_clause_and.append(sqla_col.isnot(None)) elif op == utils.FilterOperator.IS_TRUE.value: where_clause_and.append(sqla_col.is_(True)) elif op == utils.FilterOperator.IS_FALSE.value: where_clause_and.append(sqla_col.is_(False)) else: if ( op not in { utils.FilterOperator.EQUALS.value, utils.FilterOperator.NOT_EQUALS.value, } and eq is None ): raise QueryObjectValidationError( _( “Must specify a value for filters ” “with comparison operators” ) ) if op == utils.FilterOperator.EQUALS.value: where_clause_and.append(sqla_col == eq) elif op == utils.FilterOperator.NOT_EQUALS.value: where_clause_and.append(sqla_col != eq) elif op == utils.FilterOperator.GREATER_THAN.value: where_clause_and.append(sqla_col > eq) elif op == utils.FilterOperator.LESS_THAN.value: where_clause_and.append(sqla_col < eq) elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: where_clause_and.append(sqla_col >= eq) elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: where_clause_and.append(sqla_col <= eq) elif op == utils.FilterOperator.LIKE.value: where_clause_and.append(sqla_col.like(eq)) elif op == utils.FilterOperator.ILIKE.value: where_clause_and.append(sqla_col.ilike(eq)) elif ( op == utils.FilterOperator.TEMPORAL_RANGE.value and isinstance(eq, str) and col_obj is not None ): _since, _until = get_since_until_from_time_range( time_range=eq, time_shift=time_shift, extras=extras, ) where_clause_and.append( col_obj.get_time_filter( start_dttm=_since, end_dttm=_until, label=sqla_col.key, template_processor=template_processor, ) ) else: raise QueryObjectValidationError( _(“Invalid filter operation type: %(op)s”, op=op) ) |
在根据filters配置完成where之后,
如果用户书写了custom sql,那么会作为extra字段保存在入参之中
并且在custom sql之中存在着where和having的选项,
那么我们就要对extra字典中的where和having进行解析,分别拼接在where_clause_and数组和having_clause_and数组之中。
if extras:
where = extras.get(“where”) if where: try: where = template_processor.process_template(f”({where})”) except TemplateError as ex: raise QueryObjectValidationError( _( “Error in jinja expression in WHERE clause: %(msg)s”, msg=ex.message, ) ) from ex where = _process_sql_expression( expression=where, database_id=self.database_id, schema=self.schema, ) where_clause_and += [self.text(where)] having = extras.get(“having”) if having: try: having = template_processor.process_template(f”({having})”) except TemplateError as ex: raise QueryObjectValidationError( _( “Error in jinja expression in HAVING clause: %(msg)s”, msg=ex.message, ) ) from ex having = _process_sql_expression( expression=having, database_id=self.database_id, schema=self.schema, ) having_clause_and += [self.text(having)] |
之后给qry添加上where 和 having条件
if granularity:
qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) |
之后就是利用上面拼接好的order by 进行order by的拼接。
以及 limit 和offset 的拼接
if row_limit:
qry = qry.limit(row_limit) if row_offset: qry = qry.offset(row_offset) |
并且根据dataset相关的定义,拼接上from
tbl, cte = self.get_from_clause(template_processor)
qry = qry.select_from(tbl) |
这样就将这样的一个qry拼接完成,返回给上一级,利用db_engine 编译出sql
并利用sql获取到dataframe了,这样就是echarts相关查询数据全过程。
在下半部分,我们将讲述针对superset如何针对其他类型的图,书写出一个查询数据框架的。