diff --git a/cypher/models/cypher/format/format_test.go b/cypher/models/cypher/format/format_test.go index 5b431d2b..327f65d4 100644 --- a/cypher/models/cypher/format/format_test.go +++ b/cypher/models/cypher/format/format_test.go @@ -4,6 +4,7 @@ import ( "bytes" "testing" + "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/cypher/format" "github.com/specterops/dawgs/cypher/frontend" @@ -34,3 +35,125 @@ func TestCypherEmitter_HappyPath(t *testing.T) { func TestCypherEmitter_NegativeCases(t *testing.T) { test.LoadFixture(t, test.NegativeTestCases).Run(t) } + +func TestNewStringLiteral_Escaping(t *testing.T) { + testCases := []struct { + name string + input string + expected string + }{ + { + name: "backslash should be escaped", + input: `TEST\PS1-PSV$@`, + expected: `'TEST\\PS1-PSV$@'`, + }, + { + name: "single quote should be escaped", + input: `O'Brien`, + expected: `'O\'Brien'`, + }, + { + name: "both backslash and single quote", + input: `path\to\file's location`, + expected: `'path\\to\\file\'s location'`, + }, + { + name: "multiple backslashes", + input: `C:\Windows\System32`, + expected: `'C:\\Windows\\System32'`, + }, + { + name: "no special characters", + input: `simple_value`, + expected: `'simple_value'`, + }, + { + name: "empty string", + input: ``, + expected: `''`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + literal := cypher.NewStringLiteral(tc.input) + require.NotNil(t, literal) + require.Equal(t, tc.expected, literal.Value) + }) + } +} + +func TestNewStringLiteral_InQuery(t *testing.T) { + // Test that escaped string literals work correctly in actual Cypher queries + testCases := []struct { + name string + propertyKey string + value string + expectedQuery string + }{ + { + name: "backslash in objectid", + propertyKey: "objectid", + value: `TEST\PS1-PSV$@`, + expectedQuery: `match (n {objectid: 'TEST\\PS1-PSV$@'}) return n`, + }, + { + name: "single quote in name", + 
// Literal is the Cypher AST node for a literal value: a string, integer,
// float, boolean, list, map, or null.
//
// String values follow a representation contract. When Value holds a string
// it must be in Cypher source form: the raw token exactly as it would be
// written in a query, with its surrounding quote characters (' or ") present
// and every escape sequence still encoded (\\, \', \", \b, \f, \n, \r, \t).
// Non-ASCII characters are carried as raw UTF-8 bytes inside the literal
// body. Keeping strings in this form lets the AST be written back out as
// source verbatim by the format package, while backends such as the PgSQL
// translator decode the source form into the final byte sequence before
// emitting their target language.
//
// To build a string-valued Literal from an arbitrary Go string, use
// NewStringLiteral, which applies the required quoting and escaping. A string
// Value that lacks matching surrounding quotes is rejected at translation
// time.
//
// Values of any other type (int, float64, bool, and so on) are stored
// unchanged and are not covered by the source-form contract. A null literal
// sets Null to true and leaves Value unused.
type Literal struct {
	// Value is the literal payload; strings must be in Cypher source form.
	Value any
	// Null reports whether this node represents the Cypher null literal.
	Null bool
}
// decodeCypherStringLiteral converts a string literal in Cypher source form
// into the string value it denotes.
//
// raw must be the complete token including its surrounding quotes, which may
// be single (') or double (") but must match. Inside the quotes the escapes
// \\, \', \", \b, \f, \n, \r and \t are decoded (the letter escapes are
// accepted in either case). Every other byte, including raw UTF-8, is copied
// through unchanged.
//
// An error is returned when raw is shorter than two bytes, when the
// surrounding quotes are missing or mismatched, when the body ends in an
// unpaired backslash, or when a backslash is followed by a character that is
// not one of the escapes above (this includes \u and \U, which are not
// decoded here). On error the returned string is empty.
func decodeCypherStringLiteral(raw string) (string, error) {
	if len(raw) < 2 {
		return "", fmt.Errorf("invalid cypher string literal: %q", raw)
	}
	if quote := raw[0]; (quote != '\'' && quote != '"') || raw[len(raw)-1] != quote {
		return "", fmt.Errorf("invalid cypher string literal: missing or mismatched surrounding quotes: %q", raw)
	}

	body := raw[1 : len(raw)-1]

	// Most literals carry no escapes at all; return the body as-is rather
	// than copying it byte by byte.
	if strings.IndexByte(body, '\\') < 0 {
		return body, nil
	}

	var b strings.Builder
	b.Grow(len(body))
	for i := 0; i < len(body); i++ {
		if body[i] != '\\' {
			b.WriteByte(body[i])
			continue
		}

		// Step onto the character that follows the backslash.
		i++
		if i >= len(body) {
			return "", fmt.Errorf("dangling escape in string literal: %q", raw)
		}

		switch c := body[i]; c {
		case '\\', '\'', '"':
			b.WriteByte(c)
		case 'b', 'B':
			b.WriteByte('\b')
		case 'f', 'F':
			b.WriteByte('\f')
		case 'n', 'N':
			b.WriteByte('\n')
		case 'r', 'R':
			b.WriteByte('\r')
		case 't', 'T':
			b.WriteByte('\t')
		default:
			// Report the whole rune, not just its first byte, so that a
			// non-ASCII character after the backslash is shown intact.
			r, _, _ := strings.NewReader(body[i:]).ReadRune()
			return "", fmt.Errorf("invalid escape \\%c in string literal: %q", r, raw)
		}
	}
	return b.String(), nil
}
raw: `'foo\\'`, expected: expected{value: `foo\`}}, + {name: "success_-_only_backslash", raw: `'\\'`, expected: expected{value: `\`}}, + {name: "success_-_two_consecutive_backslashes", raw: `'\\\\'`, expected: expected{value: `\\`}}, + + {name: "success_-_escaped_single_quote_in_single_quoted", raw: `'O\'Brien'`, expected: expected{value: `O'Brien`}}, + {name: "success_-_escaped_double_quote_in_double_quoted", raw: `"say \"hi\""`, expected: expected{value: `say "hi"`}}, + {name: "success_-_escaped_double_quote_in_single_quoted", raw: `'a\"b'`, expected: expected{value: `a"b`}}, + {name: "success_-_escaped_single_quote_in_double_quoted", raw: `"a\'b"`, expected: expected{value: `a'b`}}, + + {name: "success_-_newline_escape", raw: `'a\nb'`, expected: expected{value: "a\nb"}}, + {name: "success_-_carriage_return_escape", raw: `'a\rb'`, expected: expected{value: "a\rb"}}, + {name: "success_-_tab_escape", raw: `'a\tb'`, expected: expected{value: "a\tb"}}, + {name: "success_-_backspace_escape", raw: `'a\bb'`, expected: expected{value: "a\bb"}}, + {name: "success_-_form_feed_escape", raw: `'a\fb'`, expected: expected{value: "a\fb"}}, + + {name: "success_-_uppercase_newline_escape", raw: `'a\Nb'`, expected: expected{value: "a\nb"}}, + {name: "success_-_uppercase_tab_escape", raw: `'a\Tb'`, expected: expected{value: "a\tb"}}, + + {name: "success_-_mssql_style_domain_object_id", raw: `'TEST\\PS1-PSV$@A-1-2-34'`, expected: expected{value: `TEST\PS1-PSV$@A-1-2-34`}}, + {name: "success_-_windows_path", raw: `'C:\\Users\\Admin'`, expected: expected{value: `C:\Users\Admin`}}, + {name: "success_-_mixed_escapes", raw: `'a\\b\nc\td'`, expected: expected{value: "a\\b\nc\td"}}, + + {name: "success_-_raw_utf8_bmp_codepoint", raw: `'café'`, expected: expected{value: "café"}}, + {name: "success_-_raw_utf8_supplementary_codepoint", raw: `'😀'`, expected: expected{value: "😀"}}, + {name: "success_-_raw_utf8_double_quoted", raw: `"日本語"`, expected: expected{value: "日本語"}}, + + {name: 
"error_-_empty_input", raw: ``, expected: expected{errContains: "invalid cypher string literal"}}, + {name: "error_-_single_quote_only", raw: `'`, expected: expected{errContains: "invalid cypher string literal"}}, + {name: "error_-_no_surrounding_quotes", raw: `foo`, expected: expected{errContains: "missing or mismatched surrounding quotes"}}, + {name: "error_-_unmatched_quote_styles", raw: `'foo"`, expected: expected{errContains: "missing or mismatched surrounding quotes"}}, + {name: "error_-_only_open_quote_with_text", raw: `'foo`, expected: expected{errContains: "missing or mismatched surrounding quotes"}}, + {name: "error_-_only_close_quote_with_text", raw: `foo'`, expected: expected{errContains: "missing or mismatched surrounding quotes"}}, + {name: "error_-_backtick_quoted", raw: "`foo`", expected: expected{errContains: "missing or mismatched surrounding quotes"}}, + {name: "error_-_dangling_backslash", raw: `'foo\'`, expected: expected{errContains: "dangling escape"}}, + {name: "error_-_invalid_escape_letter", raw: `'\x'`, expected: expected{errContains: `invalid escape \x`}}, + {name: "error_-_invalid_escape_digit", raw: `'\1'`, expected: expected{errContains: `invalid escape \1`}}, + {name: "error_-_bare_backslash_before_letter", raw: `'a\Pb'`, expected: expected{errContains: `invalid escape \P`}}, + {name: "error_-_lowercase_unicode_escape_unsupported", raw: `'\u0041'`, expected: expected{errContains: `invalid escape \u`}}, + {name: "error_-_uppercase_unicode_escape_unsupported", raw: `'\U0001F600'`, expected: expected{errContains: `invalid escape \U`}}, + } + + for _, testCase := range tt { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + if got, err := decodeCypherStringLiteral(testCase.raw); testCase.expected.errContains != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), testCase.expected.errContains) + } else { + require.NoError(t, err) + assert.Equal(t, testCase.expected.value, got) + } + }) + } +}