Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# User IDE Files
.idea
.zed

# Local integration test datasets
integration/testdata/local/
integration/testdata/local/
16 changes: 15 additions & 1 deletion cypher/models/pgsql/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,13 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {

exprStack = append(exprStack, typedNextExpr.Name)

case pgsql.LateralSubquery:
if typedNextExpr.Binding.Set {
exprStack = append(exprStack, typedNextExpr.Binding.Value, pgsql.FormattingLiteral(" "))
}

exprStack = append(exprStack, pgsql.FormattingLiteral(")"), typedNextExpr.Query, pgsql.FormattingLiteral("lateral ("))

case pgsql.Assignment:
exprStack = append(exprStack,
typedNextExpr,
Expand Down Expand Up @@ -532,6 +539,10 @@ func Expression(expression pgsql.SyntaxNode, builder *OutputBuilder) (string, er
func formatSelect(builder *OutputBuilder, selectStmt pgsql.Select) error {
builder.Write("select ")

if selectStmt.Distinct {
builder.Write("distinct ")
}

for idx, projection := range selectStmt.Projection {
if idx > 0 {
builder.Write(", ")
Expand Down Expand Up @@ -783,6 +794,9 @@ func formatSetExpression(builder *OutputBuilder, expression pgsql.SetExpression)
case pgsql.Values:
return formatNode(builder, typedSetExpression)

case pgsql.Insert:
return formatInsertStatement(builder, typedSetExpression)

case pgsql.Update:
return formatUpdateStatement(builder, typedSetExpression)

Expand Down Expand Up @@ -909,7 +923,7 @@ func formatInsertStatement(builder *OutputBuilder, insert pgsql.Insert) error {
return err
}

if len(insert.Shape.Columns) > 0 {
if insert.Shape != nil && len(insert.Shape.Columns) > 0 {
builder.Write(" (")

for idx, column := range insert.Shape.Columns {
Expand Down
67 changes: 67 additions & 0 deletions cypher/models/pgsql/format/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,73 @@ func TestFormat_TypeCastedParenthetical(t *testing.T) {
require.Equal(t, "('str')::text", formattedQuery)
}

func TestFormat_SelectDistinct(t *testing.T) {
formattedQuery, err := format.Statement(pgsql.Query{
Body: pgsql.Select{
Distinct: true,
Projection: []pgsql.SelectItem{
pgsql.Identifier("id"),
},
From: []pgsql.FromClause{{
Source: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{"node"},
},
}},
},
}, format.NewOutputBuilder())

require.Nil(t, err)
require.Equal(t, "select distinct id from node;", formattedQuery)
}

func TestFormat_LateralSubqueryJoin(t *testing.T) {
formattedQuery, err := format.Statement(pgsql.Query{
Body: pgsql.Select{
Projection: []pgsql.SelectItem{
pgsql.CompoundIdentifier{"n", "id"},
pgsql.CompoundIdentifier{"e", "id"},
},
From: []pgsql.FromClause{{
Source: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{"node"},
Binding: pgsql.AsOptionalIdentifier("n"),
},
Joins: []pgsql.Join{{
Table: pgsql.LateralSubquery{
Query: pgsql.Query{
Body: pgsql.Select{
Projection: []pgsql.SelectItem{
pgsql.CompoundIdentifier{"e", "id"},
},
From: []pgsql.FromClause{{
Source: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{"edge"},
Binding: pgsql.AsOptionalIdentifier("e"),
},
}},
Where: pgsql.NewBinaryExpression(
pgsql.CompoundIdentifier{"e", "start_id"},
pgsql.OperatorEquals,
pgsql.CompoundIdentifier{"n", "id"},
),
},
Offset: pgsql.NewLiteral(0, pgsql.Int),
},
Binding: pgsql.AsOptionalIdentifier("e"),
},
JoinOperator: pgsql.JoinOperator{
JoinType: pgsql.JoinTypeInner,
Constraint: pgsql.NewLiteral(true, pgsql.Boolean),
},
}},
}},
},
}, format.NewOutputBuilder())

require.Nil(t, err)
require.Equal(t, "select n.id, e.id from node n join lateral (select e.id from edge e where e.start_id = n.id offset 0) e on true;", formattedQuery)
}

func TestFormat_Delete(t *testing.T) {
formattedQuery, err := format.Statement(pgsql.Delete{
From: []pgsql.TableReference{{
Expand Down
3 changes: 3 additions & 0 deletions cypher/models/pgsql/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const (
FunctionUnidirectionalASPHarness Identifier = "unidirectional_asp_harness"
FunctionUnidirectionalSPHarness Identifier = "unidirectional_sp_harness"
FunctionBidirectionalASPHarness Identifier = "bidirectional_asp_harness"
FunctionBidirectionalSPHarness Identifier = "bidirectional_sp_harness"
FunctionIntArrayUnique Identifier = "uniq"
FunctionIntArraySort Identifier = "sort"
FunctionJSONBToTextArray Identifier = "jsonb_to_text_array"
Expand All @@ -26,6 +27,8 @@ const (
FunctionToUpper Identifier = "upper"
FunctionCoalesce Identifier = "coalesce"
FunctionUnnest Identifier = "unnest"
FunctionNextValue Identifier = "nextval"
FunctionPGGetSerialSequence Identifier = "pg_get_serial_sequence"
FunctionJSONBSet Identifier = "jsonb_set"
FunctionCount Identifier = "count"
FunctionStringToArray Identifier = "string_to_array"
Expand Down
23 changes: 22 additions & 1 deletion cypher/models/pgsql/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ func (s FunctionCall) TypeHint() DataType {
}

type Join struct {
Table TableReference
Table Expression
JoinOperator JoinOperator
}

Expand Down Expand Up @@ -724,6 +724,19 @@ func (s TableReference) NodeType() string {
return "table_reference"
}

type LateralSubquery struct {
Query Query
Binding models.Optional[Identifier]
}

func (s LateralSubquery) AsExpression() Expression {
return s
}

func (s LateralSubquery) NodeType() string {
return "lateral_subquery"
}

type FromClause struct {
Source Expression
Joins []Join
Expand Down Expand Up @@ -956,6 +969,14 @@ type Insert struct {
Returning []SelectItem
}

func (s Insert) AsExpression() Expression {
return s
}

func (s Insert) AsSetExpression() SetExpression {
return s
}

func (s Insert) AsStatement() Statement {
return s
}
Expand Down
2 changes: 1 addition & 1 deletion cypher/models/pgsql/test/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) {
t.Errorf("could not build query: %v", err)
}

translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil)
translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil, translate.DefaultGraphID)
if err != nil {
t.Errorf("could not translate query: %#v: %v", builtQuery, err)
}
Expand Down
11 changes: 8 additions & 3 deletions cypher/models/pgsql/test/testcase.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (s *TranslationTestCase) WriteTo(output io.Writer, kindMapper pgsql.KindMap
}
}

if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil); err != nil {
if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil, translate.DefaultGraphID); err != nil {
return err
} else if formattedQuery, err := translate.Translated(translation); err != nil {
return err
Expand Down Expand Up @@ -164,7 +164,7 @@ func (s *TranslationTestCase) Assert(t *testing.T, expectedSQL string, kindMappe
}
}

if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil); err != nil {
if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil, translate.DefaultGraphID); err != nil {
t.Fatalf("Failed to translate cypher query: %s - %v", s.Cypher, err)
} else if formattedQuery, err := translate.Translated(translation); err != nil {
t.Fatalf("Failed to format SQL translatedQuery: %v", err)
Expand Down Expand Up @@ -200,7 +200,12 @@ func (s *TranslationTestCase) AssertLive(ctx context.Context, t *testing.T, driv
}
}

if translation, err := translate.Translate(context.Background(), regularQuery, driver.KindMapper(), s.CypherParams); err != nil {
defaultGraph, hasDefaultGraph := driver.DefaultGraph()
if !hasDefaultGraph {
t.Fatalf("Driver has no default graph set")
}

if translation, err := translate.Translate(context.Background(), regularQuery, driver.KindMapper(), s.CypherParams, defaultGraph.ID); err != nil {
t.Fatalf("Failed to translate cypher query: %s - %v", s.Cypher, err)
} else if formattedQuery, err := translate.Translated(translation); err != nil {
t.Fatalf("Failed to format SQL translatedQuery: %v", err)
Expand Down
Loading
Loading