Files
signoz/pkg/querybuilder/having_expression_validator.go
Tushar Vats f71d5bf8f1 fix: added validations for having expression (#10286)
* fix: added validations for having expression

* fix: added extra validation and unit tests

* fix: added antlr based parsing for validation

* fix: added more unit tests

* fix: removed validation on having in range request validations

* fix: generated lexer files and added more unit tests

* fix: edge cases

* fix: added cmnd to scripts for generating lexer

* fix: use std lib sorting instead of selection sort

* fix: support implicit and

* fix: allow bare not in expression

* fix: added suggestion for having expression

* fix: typo

* fix: added more unit tests, handle white space difference in aggregation exp and having exp

* fix: added support for in and not, updated errors

* fix: added support for brackets list

* fix: lint error

* fix: handle non spaced expression

---------

Co-authored-by: Srikanth Chekuri <srikanth.chekuri92@gmail.com>
2026-04-02 03:52:11 +00:00

497 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package querybuilder
import (
"sort"
"strings"
"github.com/SigNoz/signoz/pkg/errors"
grammar "github.com/SigNoz/signoz/pkg/parser/havingexpression/grammar"
"github.com/antlr4-go/antlr/v4"
"github.com/huandu/go-sqlbuilder"
)
// havingExpressionRewriteVisitor walks the parse tree of a HavingExpression in a single
// pass, simultaneously rewriting user-facing references to their SQL column names and
// collecting any references that could not be resolved.
//
// Each visit method returns the rewritten expression string for its subtree:
// - Structural nodes (orExpression, andExpression, comparison, arithmetic) are
// rebuilt with canonical spacing around operators.
// - andExpression joins ALL of its primaries through the builder's And helper, so
// adjacent primaries written without an explicit AND (implicit AND) come out as
// an explicit conjunction.
// - An identifier is looked up in columnMap; on a hit the mapped SQL column name is
// returned. A name that is already one of the mapped SQL columns (a value of
// columnMap) passes through unchanged. Anything else is recorded in invalid.
// - A function call is looked up in columnMap by its full text as returned by
// ctx.GetText() (which carries no whitespace); on a hit the mapped SQL column
// name is returned, otherwise the function name is recorded in invalid and the
// arguments are not visited.
// - A STRING atom sets hasStringLiteral so the caller can return a friendly
// "aggregator results are numeric" error.
type havingExpressionRewriteVisitor struct {
// columnMap maps a user-facing reference (identifier or full function-call text)
// to the SQL column name it is rewritten to.
columnMap map[string]string
validColumns map[string]bool // set of columnMap values; identifiers already in SQL form pass through
// invalid lists unresolved reference names in the order they were first met.
invalid []string
// seen de-duplicates entries appended to invalid.
seen map[string]bool
// hasStringLiteral is set once any STRING atom has been visited.
hasStringLiteral bool
// sb supplies the And/Or/Not/In/NotIn helpers and the Args used to compile IN lists.
sb *sqlbuilder.SelectBuilder
}
// newHavingExpressionRewriteVisitor builds a visitor for the given reference-to-column
// map. The set of valid SQL column names is derived from the map's values so that
// identifiers already written in SQL form can be recognised later.
func newHavingExpressionRewriteVisitor(columnMap map[string]string) *havingExpressionRewriteVisitor {
	visitor := &havingExpressionRewriteVisitor{
		columnMap:    columnMap,
		validColumns: make(map[string]bool, len(columnMap)),
		seen:         make(map[string]bool),
		sb:           sqlbuilder.NewSelectBuilder(),
	}
	for _, sqlColumn := range columnMap {
		visitor.validColumns[sqlColumn] = true
	}
	return visitor
}
// visitQuery rewrites the root of the parse tree. It returns "" when the query node
// carries no expression child.
func (v *havingExpressionRewriteVisitor) visitQuery(ctx grammar.IQueryContext) string {
	if expr := ctx.Expression(); expr != nil {
		return v.visitExpression(expr)
	}
	return ""
}
// visitExpression rewrites an expression node by delegating to its OR-expression child.
func (v *havingExpressionRewriteVisitor) visitExpression(ctx grammar.IExpressionContext) string {
	orExpr := ctx.OrExpression()
	return v.visitOrExpression(orExpr)
}
// visitOrExpression rewrites every AND-expression branch and combines them with the
// builder's Or helper. A lone branch is returned as-is, without extra wrapping.
func (v *havingExpressionRewriteVisitor) visitOrExpression(ctx grammar.IOrExpressionContext) string {
	branches := ctx.AllAndExpression()
	if len(branches) == 1 {
		return v.visitAndExpression(branches[0])
	}
	rewritten := make([]string, 0, len(branches))
	for _, branch := range branches {
		rewritten = append(rewritten, v.visitAndExpression(branch))
	}
	return v.sb.Or(rewritten...)
}
// visitAndExpression rewrites every primary of the node and combines them with the
// builder's And helper. The grammar rule `primary ( AND primary | primary )*` lets
// primaries sit next to each other without an explicit AND; because all primaries are
// joined here, such implicit conjunctions come out explicit. A lone primary is
// returned as-is, without extra wrapping.
func (v *havingExpressionRewriteVisitor) visitAndExpression(ctx grammar.IAndExpressionContext) string {
	operands := ctx.AllPrimary()
	if len(operands) == 1 {
		return v.visitPrimary(operands[0])
	}
	rewritten := make([]string, 0, len(operands))
	for _, operand := range operands {
		rewritten = append(rewritten, v.visitPrimary(operand))
	}
	return v.sb.And(rewritten...)
}
// visitPrimary rewrites a primary, which is either a nested OR-expression or a single
// comparison, each optionally preceded by NOT.
//
// A nested OR-expression is passed through the builder's Not helper when negated and
// through its And helper (with a single argument) otherwise. A comparison is passed
// through Not when negated and returned unchanged otherwise. A primary with neither
// child yields "".
func (v *havingExpressionRewriteVisitor) visitPrimary(ctx grammar.IPrimaryContext) string {
	if nested := ctx.OrExpression(); nested != nil {
		grouped := v.visitOrExpression(nested)
		if ctx.NOT() == nil {
			return v.sb.And(grouped)
		}
		return v.sb.Not(grouped)
	}

	cmp := ctx.Comparison()
	if cmp == nil {
		return ""
	}
	rewritten := v.visitComparison(cmp)
	if ctx.NOT() == nil {
		return rewritten
	}
	return v.sb.Not(rewritten)
}
// visitComparison rewrites a comparison node.
//
// Two shapes are handled:
//   - `operand [NOT] IN list`: the left operand is rewritten and every signed number
//     of the list is passed to the builder as a raw value.
//   - `operand compOp operand`: both operands are rewritten and joined around the
//     operator text with single spaces.
//
// A node that lacks the children required by its shape yields "".
func (v *havingExpressionRewriteVisitor) visitComparison(ctx grammar.IComparisonContext) string {
	if ctx.IN() == nil {
		if ctx.CompOp() == nil || ctx.Operand(0) == nil || ctx.Operand(1) == nil {
			return ""
		}
		left := v.visitOperand(ctx.Operand(0))
		operator := ctx.CompOp().GetText()
		right := v.visitOperand(ctx.Operand(1))
		return left + " " + operator + " " + right
	}

	if ctx.Operand(0) == nil || ctx.InList() == nil {
		return ""
	}
	subject := v.visitOperand(ctx.Operand(0))

	numbers := ctx.InList().AllSignedNumber()
	members := make([]any, 0, len(numbers))
	for _, number := range numbers {
		members = append(members, sqlbuilder.Raw(number.GetText()))
	}

	var cond string
	if ctx.NOT() != nil {
		cond = v.sb.NotIn(subject, members...)
	} else {
		cond = v.sb.In(subject, members...)
	}

	// In/NotIn produce a placeholder form (lhs IN $1), so the condition has to be
	// compiled to obtain the final text. The compiled argument list is discarded:
	// every member was added as a Raw value (presumably leaving nothing to bind —
	// verify against the go-sqlbuilder docs).
	sql, _ := v.sb.Args.CompileWithFlavor(cond, sqlbuilder.ClickHouse)
	return sql
}
// visitOperand rewrites an additive expression. When the node has a nested operand
// it is `operand (+|-) term` and is rendered with single spaces around the operator;
// otherwise the node is a bare term.
func (v *havingExpressionRewriteVisitor) visitOperand(ctx grammar.IOperandContext) string {
	nested := ctx.Operand()
	if nested == nil {
		return v.visitTerm(ctx.Term())
	}

	lhs := v.visitOperand(nested)
	rhs := v.visitTerm(ctx.Term())
	if ctx.MINUS() != nil {
		return lhs + " - " + rhs
	}
	return lhs + " + " + rhs
}
// visitTerm rewrites a multiplicative expression. When the node has a nested term it
// is `term (*|/|%) factor` and is rendered with single spaces around the operator;
// otherwise the node is a bare factor.
func (v *havingExpressionRewriteVisitor) visitTerm(ctx grammar.ITermContext) string {
	nested := ctx.Term()
	if nested == nil {
		return v.visitFactor(ctx.Factor())
	}

	lhs := v.visitTerm(nested)
	rhs := v.visitFactor(ctx.Factor())

	// "*" is the fallback when neither a SLASH nor a PERCENT token is present.
	var symbol string
	switch {
	case ctx.SLASH() != nil:
		symbol = "/"
	case ctx.PERCENT() != nil:
		symbol = "%"
	default:
		symbol = "*"
	}
	return lhs + " " + symbol + " " + rhs
}
// visitFactor rewrites a factor, which is one of:
//   - a unary sign applied to another factor: (PLUS | MINUS) factor
//   - a parenthesised operand, re-wrapped through the builder's And helper with a
//     single argument
//   - an atom
//
// A factor with none of these children yields "".
func (v *havingExpressionRewriteVisitor) visitFactor(ctx grammar.IFactorContext) string {
	if ctx.Factor() != nil {
		// Unary sign: (PLUS | MINUS) factor
		sign := "+"
		if ctx.MINUS() != nil {
			sign = "-"
		}
		inner := v.visitFactor(ctx.Factor())
		// Never emit two adjacent minus signs: "--" opens a line comment in SQL, so
		// an input such as `--5` would otherwise comment out the remainder of the
		// generated query. Parenthesising keeps the arithmetic meaning intact.
		if sign == "-" && strings.HasPrefix(inner, "-") {
			return sign + "(" + inner + ")"
		}
		return sign + inner
	}
	if ctx.Operand() != nil {
		return v.sb.And(v.visitOperand(ctx.Operand()))
	}
	if ctx.Atom() == nil {
		return ""
	}
	return v.visitAtom(ctx.Atom())
}
// visitAtom rewrites a leaf of the expression: a function call, an identifier, a
// string literal or a number.
//
// String literals are returned verbatim but also set hasStringLiteral, because the
// caller rejects them with a dedicated error. An atom that has none of the expected
// children yields "".
func (v *havingExpressionRewriteVisitor) visitAtom(ctx grammar.IAtomContext) string {
	if ctx.FunctionCall() != nil {
		return v.visitFunctionCall(ctx.FunctionCall())
	}
	if ctx.Identifier() != nil {
		return v.visitIdentifier(ctx.Identifier())
	}
	if ctx.STRING() != nil {
		// String literals are never valid aggregation results; flag for a friendly error.
		v.hasStringLiteral = true
		return ctx.STRING().GetText()
	}
	// The visitor also runs on trees produced from input with syntax errors, so
	// guard against a missing NUMBER token instead of dereferencing nil, in line
	// with the nil checks used by the other visit methods.
	if ctx.NUMBER() == nil {
		return ""
	}
	return ctx.NUMBER().GetText()
}
// visitFunctionCall resolves a function call by its complete text. ctx.GetText()
// yields the call without any whitespace (e.g. "countIf(level='error')",
// "avg(sum(cpu_usage))", "count_distinct(a,b)"), and that text is the key looked up
// in columnMap. The column map is expected to hold a whitespace-stripped key for
// each aggregation expression, which is what makes the lookup insensitive to how
// the user spaced the call.
//
// On a hit the mapped SQL column name is returned. On a miss the function name is
// recorded once as an invalid reference and the call text is returned unchanged; the
// arguments are not visited.
func (v *havingExpressionRewriteVisitor) visitFunctionCall(ctx grammar.IFunctionCallContext) string {
	callText := ctx.GetText()
	if sqlColumn, found := v.columnMap[callText]; found {
		return sqlColumn
	}

	if fn := ctx.IDENTIFIER().GetText(); !v.seen[fn] {
		v.seen[fn] = true
		v.invalid = append(v.invalid, fn)
	}
	return callText
}
// visitIdentifier resolves a bare identifier.
//
// Resolution order:
//  1. A key of columnMap is replaced by its SQL column name.
//  2. A name that is itself one of the mapped SQL columns (e.g. __result_0 typed
//     directly by the user) is returned unchanged.
//  3. Anything else is recorded once as an invalid reference and returned unchanged.
func (v *havingExpressionRewriteVisitor) visitIdentifier(ctx grammar.IIdentifierContext) string {
	ident := ctx.IDENTIFIER().GetText()
	if sqlColumn, found := v.columnMap[ident]; found {
		return sqlColumn
	}

	if !v.validColumns[ident] && !v.seen[ident] {
		v.seen[ident] = true
		v.invalid = append(v.invalid, ident)
	}
	return ident
}
// rewriteAndValidate is the single-pass implementation used by all RewriteFor* methods.
// It parses expression once, rewrites every reference to its SQL column name and
// returns the rewritten expression, or an invalid-input error describing the first
// class of problem found.
//
// Checks are applied in this order:
//  1. String literals used as operands → "`Having` expression contains string
//     literals" (aggregator results are numeric). This takes priority over invalid
//     references so a bare string operand is reported as a string literal.
//  2. Unresolved references (identifiers or function calls not present in the column
//     map) → error listing the invalid names and the valid references, plus a
//     suggestion when exactly one plain identifier is within edit distance 3 of a
//     valid reference.
//  3. ANTLR lexer/parser syntax errors → error carrying the collected messages, plus
//     a suggestion when there is exactly one error matching a known pattern.
//
// The visitor runs before the syntax-error check so that expressions with
// recoverable parse errors (e.g. sum(count())) still produce an actionable
// "invalid reference" message rather than a raw syntax error.
func (r *HavingExpressionRewriter) rewriteAndValidate(expression string) (string, error) {
	original := strings.TrimSpace(expression)

	// Parse the expression once, collecting lexer and parser errors separately.
	input := antlr.NewInputStream(expression)
	lexer := grammar.NewHavingExpressionLexer(input)
	lexerErrListener := NewErrorListener()
	lexer.RemoveErrorListeners()
	lexer.AddErrorListener(lexerErrListener)

	tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
	p := grammar.NewHavingExpressionParser(tokens)
	parserErrListener := NewErrorListener()
	p.RemoveErrorListeners()
	p.AddErrorListener(parserErrListener)

	tree := p.Query()

	// Rewrite and collect unresolved references in one walk of the tree.
	v := newHavingExpressionRewriteVisitor(r.columnMap)
	result := v.visitQuery(tree)

	// String literals: the grammar accepts STRING tokens as atoms so the parser can
	// keep going, but the visitor flags them; aggregator results are always numeric.
	if v.hasStringLiteral {
		return "", errors.NewInvalidInputf(
			errors.CodeInvalidInput,
			"`Having` expression contains string literals",
		).WithAdditional("Aggregator results are numeric")
	}

	if len(v.invalid) > 0 {
		sort.Strings(v.invalid)

		// Map iteration order is random; sort so the message is deterministic.
		validKeys := make([]string, 0, len(r.columnMap))
		for k := range r.columnMap {
			validKeys = append(validKeys, k)
		}
		sort.Strings(validKeys)

		additional := []string{"Valid references are: [" + strings.Join(validKeys, ", ") + "]"}
		if len(v.invalid) == 1 {
			inv := v.invalid[0]
			// Only suggest for plain identifier typos, not for unresolved function
			// calls: a function call will appear as "name(" in the expression, and
			// the closest valid key may itself contain "(" (e.g. "sum(a)"), making
			// a simple string substitution produce a corrupt expression.
			isFuncCall := strings.Contains(original, inv+"(")
			match, dist := closestMatch(inv, validKeys)
			// dist is negative when there are no valid references at all; without
			// this guard the identifier would be replaced by the empty string and
			// the "suggestion" would be the expression with the name removed.
			// NOTE(review): ReplaceAll works on substrings, so an invalid name that
			// is also part of a longer token would be rewritten there too.
			if dist >= 0 && dist <= 3 && !isFuncCall && !strings.Contains(match, "(") {
				corrected := strings.ReplaceAll(original, inv, match)
				additional = append(additional, "Suggestion: `"+corrected+"`")
			}
		}
		return "", errors.NewInvalidInputf(
			errors.CodeInvalidInput,
			"Invalid references in `Having` expression: [%s]",
			strings.Join(v.invalid, ", "),
		).WithAdditional(additional...)
	}

	// ANTLR syntax errors. The original expression was parsed, so error messages
	// already reference the user's own token names; no re-parsing is needed.
	// The three-index slice caps capacity so append copies instead of writing into
	// the lexer listener's backing array.
	lexErrs := lexerErrListener.SyntaxErrors
	allSyntaxErrors := append(lexErrs[:len(lexErrs):len(lexErrs)], parserErrListener.SyntaxErrors...)
	if len(allSyntaxErrors) > 0 {
		msgs := make([]string, 0, len(allSyntaxErrors))
		for _, se := range allSyntaxErrors {
			if m := se.Error(); m != "" {
				msgs = append(msgs, m)
			}
		}
		detail := strings.Join(msgs, "; ")
		if detail == "" {
			detail = "check the expression syntax"
		}
		additional := []string{detail}
		// For single-error expressions, try to produce an actionable suggestion.
		if len(allSyntaxErrors) == 1 {
			if s := havingSuggestion(allSyntaxErrors[0], original); s != "" {
				additional = append(additional, "Suggestion: `"+s+"`")
			}
		}
		return "", errors.NewInvalidInputf(
			errors.CodeInvalidInput,
			"Syntax error in `Having` expression",
		).WithAdditional(additional...)
	}

	return result, nil
}
// havingSuggestion proposes a corrected expression for a single syntax error that
// matches a well-known mistake, or returns "" when no suggestion can be formed.
// It is meant to be called only when there is exactly one syntax error.
//
// se is the syntax error; original is the expression as typed (it is trimmed here).
//
// Recognised patterns, in the order they are tried:
//
// Error at EOF:
//   - Pattern 4: dangling AND/OR at the end:        total > 100 AND → total > 100
//   - Pattern 1: bare reference, no comparison yet: count()         → count() > 0
//   - Pattern 2: comparison without right operand:  count() >       → count() > 0
//   - Pattern 3: unclosed parenthesis:              (total > 100    → (total > 100)
//   - Pattern 6: unclosed IN bracket list:          count() IN [1, 2 → count() IN [1, 2]
//
// Error elsewhere:
//   - Pattern 5: leading OR at column 0:            OR total > 100  → total > 100
func havingSuggestion(se *SyntaxErr, original string) string {
	expr := strings.TrimSpace(original)
	exprUpper := strings.ToUpper(expr)

	if se.TokenTxt != "EOF" {
		// Pattern 5: the offending token is an OR sitting at the very start.
		if se.TokenTxt == "'OR'" && se.Col == 0 && strings.HasPrefix(exprUpper, "OR ") {
			return strings.TrimSpace(expr[3:])
		}
		return ""
	}

	switch {
	// Pattern 4 goes first so that "expr AND" is not mistaken for Pattern 1.
	case strings.HasSuffix(exprUpper, " AND"):
		return strings.TrimSpace(expr[:len(expr)-4])
	case strings.HasSuffix(exprUpper, " OR"):
		return strings.TrimSpace(expr[:len(expr)-3])

	// Pattern 1: an IDENTIFIER is among the expected tokens, and neither a trailing
	// comparison operator (Pattern 2) nor an open parenthesis (Pattern 3) explains
	// the error.
	case expectedContains(se, "IDENTIFIER") && !endsWithComparisonOp(expr) && !hasUnclosedParen(expr):
		return expr + " > 0"

	// Pattern 2: the operator is there, only the right-hand number is missing.
	case expectedContains(se, "number") && endsWithComparisonOp(expr):
		return expr + " 0"

	// Pattern 3: len > 1 avoids suggesting a useless "()" for a bare "(".
	case expectedContains(se, ")") && hasUnclosedParen(expr) && len(expr) > 1:
		return expr + ")"

	// Pattern 6: same idea for the bracketed IN list.
	case expectedContains(se, "]") && hasUnclosedBracket(expr) && len(expr) > 1:
		return expr + "]"
	}
	return ""
}
// expectedContains reports whether label is one of the entries of se.Expected.
func expectedContains(se *SyntaxErr, label string) bool {
	for i := 0; i < len(se.Expected); i++ {
		if se.Expected[i] == label {
			return true
		}
	}
	return false
}
// hasUnclosedParen reports whether s contains more '(' than ')'. Only the totals are
// compared; the order in which the parentheses appear is not taken into account.
func hasUnclosedParen(s string) bool {
	opened, closed := 0, 0
	// Both characters are ASCII, so scanning bytes is equivalent to scanning runes.
	for i := 0; i < len(s); i++ {
		if s[i] == '(' {
			opened++
		} else if s[i] == ')' {
			closed++
		}
	}
	return opened > closed
}
// hasUnclosedBracket reports whether s contains more '[' than ']'. Only the totals
// are compared; the order in which the brackets appear is not taken into account.
func hasUnclosedBracket(s string) bool {
	opened, closed := 0, 0
	// Both characters are ASCII, so scanning bytes is equivalent to scanning runes.
	for i := 0; i < len(s); i++ {
		if s[i] == '[' {
			opened++
		} else if s[i] == ']' {
			closed++
		}
	}
	return opened > closed
}
// closestMatch returns the element of candidates with the smallest Levenshtein
// distance to query, together with that distance. On ties the earliest candidate
// wins. When candidates is empty it returns ("", -1).
func closestMatch(query string, candidates []string) (string, int) {
	if len(candidates) == 0 {
		return "", -1
	}

	winner := candidates[0]
	lowest := levenshtein(query, winner)
	for _, cand := range candidates[1:] {
		if d := levenshtein(query, cand); d < lowest {
			winner, lowest = cand, d
		}
	}
	return winner, lowest
}
// levenshtein computes the edit distance between a and b, counted in runes: the
// minimum number of single-rune insertions, deletions and substitutions needed to
// turn one string into the other.
func levenshtein(a, b string) int {
	src, dst := []rune(a), []rune(b)

	// dist holds one row of the dynamic-programming table and is updated in place.
	dist := make([]int, len(dst)+1)
	for j := range dist {
		dist[j] = j
	}

	for i, sc := range src {
		// diag is the table value up and to the left of the cell being written.
		diag := dist[0]
		dist[0] = i + 1
		for j, dc := range dst {
			above := dist[j+1]
			if sc == dc {
				dist[j+1] = diag
			} else {
				cheapest := diag
				if above < cheapest {
					cheapest = above
				}
				if left := dist[j]; left < cheapest {
					cheapest = left
				}
				dist[j+1] = cheapest + 1
			}
			diag = above
		}
	}
	return dist[len(dst)]
}
// endsWithComparisonOp reports whether s ends with one of the comparison operators
// ">=", "<=", "!=", "<>", "==", ">", "<" or "=".
//
// Every one of those operators ends in '>', '<' or '=', and each of these three
// characters is itself an operator in the list, so inspecting the final byte is
// sufficient.
func endsWithComparisonOp(s string) bool {
	if s == "" {
		return false
	}
	return strings.IndexByte("<>=", s[len(s)-1]) >= 0
}