Superset 如何拼接SQL 并执行的(上)

本章,我们我们将以Superset创建图表的相关接口为脉络。讲解Superset如何通过入参生成相关Dashboard并返回Dataframe。

首先,在Superset之中,存在着两种类型的图标,对应着两种类型的数据查询接口。

http://localhost:9000/superset/explore_json/

以及

http://localhost:9000/api/v1/chart/data

两者分别对应的不同的前端类型,往往Echarts类型的chart,在获取数据的时候都是走的/data接口。是一个统一的处理逻辑。而/explore_json则是一个框架代码,父类中封装多个函数,并由一个上帝类将其组装起来并调用。而不同的chart则是存在在这一个框架中的不同类。不同的类可以选择性的替换父类中的逻辑,从而实现自己的查询逻辑。

而我们也将以/data接口作为入口,先查看Echarts类型相关的chart如何进行sql拼接查询的过程。再去看看/explore_json

http://localhost:9000/api/v1/chart/data的接口位于superset/charts/data/api.py文件中。

而具体到这个接口之中,则是将传入的参数,首先将Json入参转换为了一个query_context

query_context = self._create_query_context_from_form(json_body)

command = ChartDataCommand(query_context)

command.validate()

并封装到ChartDataCommand之中。并在最后调用了该类的_get_data_response函数进行查询

return self._get_data_response(

command, form_data=form_data, datasource=query_context.datasource

)

对于转换的步骤

主要就是利用ChartDataQueryContextSchema().load(form_data)

而在Schema之中,则是利用了Schema对应工厂类进行的转换。

superset.common.query_context_factory.QueryContextFactory下的create函数

主要就是格式化入参,查看是否是创建好的slice对象。准备数据源对象。

然后封装为了command对象,并对是否有权限访问这个数据源进行了校验。

而在get_data_response方法中,则是直接调用了command.run,获取了result进行返回。

在command.run之中,主要的查询逻辑存在于了

payload = self._query_context.get_payload(

cache_query_context=cache_query_context, force_cached=force_cached

)

也就是反过来调用上面转换得到的query_context中的get_payload中方法。

在get_payload方法之中

def get_payload(

self,

cache_query_context: Optional[bool] = False,

force_cached: bool = False,

) -> Dict[str, Any]:

“””Returns the query results with both metadata and data”””
return self._processor.get_payload(cache_query_context, force_cached)

则是调用processor处理器中的get_payload

这又是一个障眼法,processor会在初始化的时候传入query_context,相当于主动权交给了processor。

而在get_payload之中,则是对于每个query_context中queries数组下的每一个query进行了查询。

query_results = [

get_query_results(

query_obj.result_type or self._query_context.result_type,

self._query_context,

query_obj,

force_cached,

)

for query_obj in self._query_context.queries

]

而在get_query_results函数之中,则是根据query中的result_type调用了不同函数,比如sample进行采样操作,query则只是拼接sql,result则是拼接sql并进行查询。

这里我们以result对应的函数 get_result为线索继续往下看。_get_result 调用了_get_full函数,而在get_full函数中

def _get_full(

query_context: QueryContext,

query_obj: QueryObject,

force_cached: Optional[bool] = False,

) -> Dict[str, Any]:

datasource = _get_datasource(query_context, query_obj)

result_type = query_obj.result_type or query_context.result_type

payload = query_context.get_df_payload(query_obj, force_cached=force_cached)

applied_template_filters = payload.get(“applied_template_filters”, [])

df = payload[“df”]

status = payload[“status”]

if status != QueryStatus.FAILED:

payload[“colnames”] = list(df.columns)

payload[“indexnames”] = list(df.index)

payload[“coltypes”] = extract_dataframe_dtypes(df, datasource)

payload[“data”] = query_context.get_data(df)

payload[“result_format”] = query_context.result_format

del payload[“df”]

filters = query_obj.filter

filter_columns = cast(List[str], [flt.get(“col”) for flt in filters])

columns = set(datasource.column_names)

applied_time_columns, rejected_time_columns = get_time_filter_status(

datasource, query_obj.applied_time_extras

)

payload[“applied_filters”] = [

{“column”: get_column_name(col)}

for col in filter_columns

if is_adhoc_column(col) or col in columns or col in applied_template_filters

] + applied_time_columns

payload[“rejected_filters”] = [

{“reason”: ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, “column”: col}

for col in filter_columns

if not is_adhoc_column(col)

and col not in columns

and col not in applied_template_filters

] + rejected_time_columns

if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED:

return {

“data”: payload.get(“data”),

“colnames”: payload.get(“colnames”),

“coltypes”: payload.get(“coltypes”),

}

return payload

在上面这段代码之中,首先是利用query_context中获取到了dataframe.

payload = query_context.get_df_payload(query_obj, force_cached=force_cached)

然后是进行后缀处理。

比如获取字段名字,索引名字,字段类型,以及利用get_data函数将dataframe转换为不同的格式。

if status != QueryStatus.FAILED:

payload[“colnames”] = list(df.columns)

payload[“indexnames”] = list(df.index)

payload[“coltypes”] = extract_dataframe_dtypes(df, datasource)

payload[“data”] = query_context.get_data(df)

payload[“result_format”] = query_context.result_format

del payload[“df”]

最后则是一些filter拼接后返回。

那么就核心来看,还是存在于get_df_payload函数之中。

def get_df_payload(

self, query_obj: QueryObject, force_cached: Optional[bool] = False

) -> Dict[str, Any]:

“””Handles caching around the df payload retrieval”””
cache_key = self.query_cache_key(query_obj)

cache = QueryCacheManager.get(

cache_key,

CacheRegion.DATA,

self._query_context.force,

force_cached,

)

if query_obj and cache_key and not cache.is_loaded:

try:

invalid_columns = [

col

for col in get_column_names_from_columns(query_obj.columns)

+ get_column_names_from_metrics(query_obj.metrics or [])

if (

col not in self._qc_datasource.column_names

and col != DTTM_ALIAS

)

]

if invalid_columns:

raise QueryObjectValidationError(

_(

“Columns missing in dataset: %(invalid_columns)s”,

invalid_columns=invalid_columns,

)

)

query_result = self.get_query_result(query_obj)

annotation_data = self.get_annotation_data(query_obj)

cache.set_query_result(

key=cache_key,

query_result=query_result,

annotation_data=annotation_data,

force_query=self._query_context.force,

timeout=self.get_cache_timeout(),

datasource_uid=self._qc_datasource.uid,

region=CacheRegion.DATA,

)

except QueryObjectValidationError as ex:

cache.error_message = str(ex)

cache.status = QueryStatus.FAILED

# the N-dimensional DataFrame has converteds into flat DataFrame

# by `flatten operator`, “comma” in the column is escaped by `escape_separator`

# the result DataFrame columns should be unescaped

label_map = {

unescape_separator(col): [

unescape_separator(col) for col in re.split(r”(?<!\\),\s”, col)

]

for col in cache.df.columns.values

}

cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values]

return {

“cache_key”: cache_key,

“cached_dttm”: cache.cache_dttm,

“cache_timeout”: self.get_cache_timeout(),

“df”: cache.df,

“applied_template_filters”: cache.applied_template_filters,

“annotation_data”: cache.annotation_data,

“error”: cache.error_message,

“is_cached”: cache.is_cached,

“query”: cache.query,

“status”: cache.status,

“stacktrace”: cache.stacktrace,

“rowcount”: len(cache.df.index),

“from_dttm”: query_obj.from_dttm,

“to_dttm”: query_obj.to_dttm,

“label_map”: label_map,

}

这个函数中,会首先尝试查询缓存中是否存在

如果获取不到,在调用query_result = self.get_query_result(query_obj)进行查询

查询完成之后存入缓存之中。

而从self.get_query_result函数进去。

核心的代码就几行,

query = “”

if isinstance(query_context.datasource, Query):

# todo(hugh): add logic to manage all sip68 models here
result = query_context.datasource.exc_query(query_object.to_dict())

else:

result = query_context.datasource.query(query_object.to_dict())

query = result.query + “;\n\n”

df = result.df

主要逻辑还是利用datasource进行的查询。

而在datasource的相关代码中,则是先进行了sql拼接

query_str_ext = self.get_query_str_extended(query_obj)

在之后执行sql获取到了dataframe。

df = self.database.get_df(sql, self.schema, mutator=assign_column_label)

而在拼接sql的时候,则是先拼接为SqlaQuery,再利用对应的db_engine进行转换。

def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:

sqlaq = self.get_sqla_query(**query_obj)

sql = self.database.compile_sqla_query(sqlaq.sqla_query)

sql = self._apply_cte(sql, sqlaq.cte)

sql = sqlparse.format(sql, reindent=True)

sql = self.mutate_query_from_config(sql)

return QueryStringExtended(

applied_template_filters=sqlaq.applied_template_filters,

labels_expected=sqlaq.labels_expected,

prequeries=sqlaq.prequeries,

sql=sql,

)

而主要的生成SqlaQuery的函数,就是相关的核心代码了。

这里我们进行拆解。

在函数之中,首先是初始化相关参数

extras = extras or {}

time_grain = extras.get(“time_grain_sqla”)

template_kwargs = {

“columns”: columns,

“from_dttm”: from_dttm.isoformat() if from_dttm else None,

“groupby”: groupby,

“metrics”: metrics,

“row_limit”: row_limit,

“row_offset”: row_offset,

“time_column”: granularity,

“time_grain”: time_grain,

“to_dttm”: to_dttm.isoformat() if to_dttm else None,

“table_columns”: [col.column_name for col in self.columns],

“filter”: filter,

}

columns = columns or []

groupby = groupby or []

series_column_names = utils.get_column_names(series_columns or [])

# deprecated, to be removed in 2.0

if is_timeseries and timeseries_limit:

series_limit = timeseries_limit

series_limit_metric = series_limit_metric or timeseries_limit_metric

template_kwargs.update(self.template_params_dict)

extra_cache_keys: List[Any] = []

template_kwargs[“extra_cache_keys”] = extra_cache_keys

removed_filters: List[str] = []

applied_template_filters: List[str] = []

template_kwargs[“removed_filters”] = removed_filters

template_kwargs[“applied_filters”] = applied_template_filters

template_processor = self.get_template_processor(**template_kwargs)

db_engine_spec = self.db_engine_spec

prequeries: List[str] = []

orderby = orderby or []

need_groupby = bool(metrics is not None or groupby)

metrics = metrics or []

# For backward compatibility

if granularity not in self.dttm_cols and granularity is not None:

granularity = self.main_dttm_col

columns_by_name: Dict[str, TableColumn] = {

col.column_name: col for col in self.columns

}

metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}

这里的重点是获取到了当前查询使用到的Dataset中的column和metric。

也就是最后两行。

然后是判断是否存在要求是时间类型的图但是没有时间字段,没有metric并且没有columns和groupby。

if not granularity and is_timeseries:

raise QueryObjectValidationError(

_(

“Datetime column not provided as part table configuration ”

“and is required by this type of chart”

)

)

if not metrics and not columns and not groupby:

raise QueryObjectValidationError(_(“Empty query?”))

然后是遍历传入的metric进行相关处理

for metric in metrics:

if utils.is_adhoc_metric(metric):

assert isinstance(metric, dict)

metrics_exprs.append(

self.adhoc_metric_to_sqla(

metric=metric,

columns_by_name=columns_by_name,

template_processor=template_processor,

)

)

elif isinstance(metric, str) and metric in metrics_by_name:

metrics_exprs.append(

metrics_by_name[metric].get_sqla_col(

template_processor=template_processor

)

)

else:

raise QueryObjectValidationError(

_(“Metric ‘%(metric)s’ does not exist”, metric=metric)

)

其中主要是adhoc_metric_to_sqla,由于metric支持传入dataset中本身的metric对象和custom sql。这里进行了判断,并进行了不同的处理

其中从传入的columns_by_name 字典中获取到字段对象

然后最重要的是利用self.sqla_aggregation字段获取获取到对应的聚合函数,处理对应字段,最后利用对应的db_engine进行编译即可。

def adhoc_metric_to_sqla(

self,

metric: AdhocMetric,

columns_by_name: Dict[str, TableColumn],

template_processor: Optional[BaseTemplateProcessor] = None,

) -> ColumnElement:

“””
Turn an adhoc metric into a sqlalchemy column.

:param dict metric: Adhoc metric definition
:param dict columns_by_name: Columns for the current table
:param template_processor: template_processor instance
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
“””
expression_type = metric.get(“expressionType”)

label = utils.get_metric_name(metric)

if expression_type == utils.AdhocMetricExpressionType.SIMPLE:

metric_column = metric.get(“column”) or {}

column_name = cast(str, metric_column.get(“column_name”))

table_column: Optional[TableColumn] = columns_by_name.get(column_name)

if table_column:

sqla_column = table_column.get_sqla_col(

template_processor=template_processor

)

else:

sqla_column = column(column_name)

sqla_metric = self.sqla_aggregations[metric[“aggregate”]](sqla_column)

elif expression_type == utils.AdhocMetricExpressionType.SQL:

expression = _process_sql_expression(

expression=metric[“sqlExpression”],

database_id=self.database_id,

schema=self.schema,

template_processor=template_processor,

)

sqla_metric = literal_column(expression)

else:

raise QueryObjectValidationError(“Adhoc metric expressionType is invalid”)

return self.make_sqla_column_compatible(sqla_metric, label)

这里我们直接给出对应的sqla_aggregations字段

sqla_aggregations = {

“COUNT_DISTINCT”: lambda column_name: sa.func.COUNT(sa.distinct(column_name)),

“COUNT”: sa.func.COUNT,

“SUM”: sa.func.SUM,

“AVG”: sa.func.AVG,

“MIN”: sa.func.MIN,

“MAX”: sa.func.MAX,

}

其中是直接利用了sqlachemy相关的func实现的。

如果是custom sql 则是直接parse sql之后封装为literal_column(expression)

这里循环处理完成metric之后,接下来进行orderby的处理

遍历orderby字段过程中,主要逻辑就是根据是否是聚合函数,来判断是否需要group by

在处理完成之后,获取到了need_group 字段值之后,根据是否是需要聚合,来使用groups 字段或者columns 字段来拼接select_exprs数组以及groupby_all_columns字典。

然后根据是否是时序图

来添加时间字段到select_exprs数组之中。

if is_timeseries:

timestamp = dttm_col.get_timestamp_expression(

time_grain=time_grain, template_processor=template_processor

)

# always put timestamp as the first column

select_exprs.insert(0, timestamp)

groupby_all_columns[timestamp.name] = timestamp

以及根据是否具有时间字段,来在time_filters添加过滤器。

time_filters.append(

dttm_col.get_time_filter(

start_dttm=from_dttm,

end_dttm=to_dttm,

template_processor=template_processor,

)

)

在接下来去重所选字段之后,进行实际sqlaquery的拼接。

qry = sa.select(select_exprs)

首先是根据选择字段创建qry对象。

填充group by 字段

if groupby_all_columns:

qry = qry.group_by(*groupby_all_columns.values())

再其次是填充where和having字段

对于where字段,则是从入参中的filters数组进行取出拼接

这个数组每一个下标对应着一个字典,包含了col val op三个key

对于每一个数组,都是获取到字段sql,根据op,进行相关的sql拼接

assert isinstance(eq, (tuple, list))

if len(eq) == 0:

raise QueryObjectValidationError(

_(“Filter value list cannot be empty”)

)

if len(eq) > len(

eq_without_none := [x for x in eq if x is not None]

):

is_null_cond = sqla_col.is_(None)

if eq:

cond = or_(is_null_cond, sqla_col.in_(eq_without_none))

else:

cond = is_null_cond

else:

cond = sqla_col.in_(eq)

if op == utils.FilterOperator.NOT_IN.value:

cond = ~cond

where_clause_and.append(cond)

elif op == utils.FilterOperator.IS_NULL.value:

where_clause_and.append(sqla_col.is_(None))

elif op == utils.FilterOperator.IS_NOT_NULL.value:

where_clause_and.append(sqla_col.isnot(None))

elif op == utils.FilterOperator.IS_TRUE.value:

where_clause_and.append(sqla_col.is_(True))

elif op == utils.FilterOperator.IS_FALSE.value:

where_clause_and.append(sqla_col.is_(False))

else:

if (

op

not in {

utils.FilterOperator.EQUALS.value,

utils.FilterOperator.NOT_EQUALS.value,

}

and eq is None

):

raise QueryObjectValidationError(

_(

“Must specify a value for filters ”

“with comparison operators”

)

)

if op == utils.FilterOperator.EQUALS.value:

where_clause_and.append(sqla_col == eq)

elif op == utils.FilterOperator.NOT_EQUALS.value:

where_clause_and.append(sqla_col != eq)

elif op == utils.FilterOperator.GREATER_THAN.value:

where_clause_and.append(sqla_col > eq)

elif op == utils.FilterOperator.LESS_THAN.value:

where_clause_and.append(sqla_col < eq)

elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:

where_clause_and.append(sqla_col >= eq)

elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:

where_clause_and.append(sqla_col <= eq)

elif op == utils.FilterOperator.LIKE.value:

where_clause_and.append(sqla_col.like(eq))

elif op == utils.FilterOperator.ILIKE.value:

where_clause_and.append(sqla_col.ilike(eq))

elif (

op == utils.FilterOperator.TEMPORAL_RANGE.value

and isinstance(eq, str)

and col_obj is not None

):

_since, _until = get_since_until_from_time_range(

time_range=eq,

time_shift=time_shift,

extras=extras,

)

where_clause_and.append(

col_obj.get_time_filter(

start_dttm=_since,

end_dttm=_until,

label=sqla_col.key,

template_processor=template_processor,

)

)

else:

raise QueryObjectValidationError(

_(“Invalid filter operation type: %(op)s”, op=op)

)

在根据filters配置完成where之后,

如果用户书写了custom sql,那么会作为extra字段保存在入参之中

并且在custom sql之中存在着where和having的选项,

那么我们就要对extra字典中的where和having进行解析,分别拼接在where_clause_and数组和having_clause_and数组之中。

if extras:

where = extras.get(“where”)

if where:

try:

where = template_processor.process_template(f”({where})”)

except TemplateError as ex:

raise QueryObjectValidationError(

_(

“Error in jinja expression in WHERE clause: %(msg)s”,

msg=ex.message,

)

) from ex

where = _process_sql_expression(

expression=where,

database_id=self.database_id,

schema=self.schema,

)

where_clause_and += [self.text(where)]

having = extras.get(“having”)

if having:

try:

having = template_processor.process_template(f”({having})”)

except TemplateError as ex:

raise QueryObjectValidationError(

_(

“Error in jinja expression in HAVING clause: %(msg)s”,

msg=ex.message,

)

) from ex

having = _process_sql_expression(

expression=having,

database_id=self.database_id,

schema=self.schema,

)

having_clause_and += [self.text(having)]

之后给qry添加上where 和 having条件

if granularity:

qry = qry.where(and_(*(time_filters + where_clause_and)))

else:

qry = qry.where(and_(*where_clause_and))

qry = qry.having(and_(*having_clause_and))

之后就是利用上面拼接好的order by 进行order by的拼接。

以及 limit 和offset 的拼接

if row_limit:

qry = qry.limit(row_limit)

if row_offset:

qry = qry.offset(row_offset)

并且根据dataset相关的定义,拼接上from

tbl, cte = self.get_from_clause(template_processor)

qry = qry.select_from(tbl)

这样就将这样的一个qry拼接完成,返回给上一级,利用db_engine 编译出sql

并利用sql获取到dataframe了,这样就是echarts相关查询数据全过程。

在下半部分,我们将讲述针对superset如何针对其他类型的图,书写出一个查询数据框架的。

发表评论

邮箱地址不会被公开。 必填项已用*标注