diff --git a/pkg/proxy/driver/queryctx.go b/pkg/proxy/driver/queryctx.go index f0c9f53..30ada20 100644 --- a/pkg/proxy/driver/queryctx.go +++ b/pkg/proxy/driver/queryctx.go @@ -114,7 +114,7 @@ func (q *QueryCtxImpl) Execute(ctx context.Context, sql string) (*gomysql.Result tableName := wast.ExtractFirstTableNameFromStmt(stmt) ctx = wast.CtxWithAstTableName(ctx, tableName) - sqlParadigm, err := q.extractSqlParadigm(ctx, sql) + sqlParadigm, err := extractStmtParadigm(stmt) if err != nil { return nil, err } diff --git a/pkg/proxy/driver/queryctx_exec.go b/pkg/proxy/driver/queryctx_exec.go index 3c97a9f..0b3a2dc 100644 --- a/pkg/proxy/driver/queryctx_exec.go +++ b/pkg/proxy/driver/queryctx_exec.go @@ -5,14 +5,14 @@ import ( "hash/crc32" "strings" - "github.com/tidb-incubator/weir/pkg/proxy/constant" - wast "github.com/tidb-incubator/weir/pkg/util/ast" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/logutil" gomysql "github.com/siddontang/go-mysql/mysql" + "github.com/tidb-incubator/weir/pkg/proxy/constant" + wast "github.com/tidb-incubator/weir/pkg/util/ast" "go.uber.org/zap" ) @@ -60,17 +60,21 @@ func (q *QueryCtxImpl) getRateLimiterKey(ctx context.Context, rateLimiter RateLi } } -func (q *QueryCtxImpl) extractSqlParadigm(ctx context.Context, sql string) (string, error) { - charsetInfo, collation := q.sessionVars.GetCharsetInfo() - featureStmt, err := q.parser.ParseOneStmt(sql, charsetInfo, collation) +func extractStmtParadigm(stmt ast.StmtNode) (string, error) { + visitor, err := wast.ExtractAstVisit(stmt) if err != nil { return "", err } - visitor, err := wast.ExtractAstVisit(featureStmt) + return visitor.SqlFeature(), nil +} + +func (q *QueryCtxImpl) extractSqlParadigm(ctx context.Context, sql string) (string, error) { + charsetInfo, collation := q.sessionVars.GetCharsetInfo() + featureStmt, err := q.parser.ParseOneStmt(sql, charsetInfo, collation) if err != nil { return "", err } - return visitor.SqlFeature(), nil + return extractStmtParadigm(featureStmt) } func (q *QueryCtxImpl) executeStmt(ctx context.Context, sql string, stmtNode ast.StmtNode) (*gomysql.Result, error) {