Skip to content

Juanmardefago/rw dialect fixes #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: graphops/semantic-type-annotations
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions db_proto/sql/risingwave/accumulator_inserter.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ func createInsertFromDescriptorAcc(table *schema.Table, dialect sql2.Dialect) (s
fields := table.Columns

var fieldNames []string

// Add standard block metadata columns
fieldNames = append(fieldNames, "block_number")
fieldNames = append(fieldNames, "block_timestamp")

Expand All @@ -82,7 +84,7 @@ func createInsertFromDescriptorAcc(table *schema.Table, dialect sql2.Dialect) (s
continue
}

if field.IsRepeated || field.IsExtension { //not a direct child
if field.IsRepeated || field.IsExtension {
continue
}
fieldNames = append(fieldNames, field.QuotedName())
Expand All @@ -92,11 +94,8 @@ func createInsertFromDescriptorAcc(table *schema.Table, dialect sql2.Dialect) (s
tableName,
strings.Join(fieldNames, ", "),
), nil

}



func (i *AccumulatorInserter) insert(table string, values []any, database *Database) error {
var v []string
if table == "_cursor_" {
Expand Down
143 changes: 60 additions & 83 deletions db_proto/sql/risingwave/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"hash/fnv"
"sort"
"strings"
"time"

sql2 "github.com/streamingfast/substreams-sink-sql/db_proto/sql"
"github.com/streamingfast/substreams-sink-sql/db_proto/sql/schema"
Expand All @@ -27,7 +26,7 @@ const risingwaveStaticSql = `
) ON CONFLICT OVERWRITE;

CREATE TABLE IF NOT EXISTS "%s"._blocks_ (
number INTEGER,
number INTEGER PRIMARY KEY,
hash VARCHAR NOT NULL,
timestamp TIMESTAMP WITH TIME ZONE NOT NULL
);
Expand Down Expand Up @@ -68,49 +67,47 @@ func (d *DialectRisingwave) UseDeletedField() bool {
}

func (d *DialectRisingwave) init() error {
d.AddPrimaryKeySql("_blocks_", fmt.Sprintf("alter table %s._blocks_ add constraint block_pk primary key (number);", d.schemaName))
return nil
}

func (d *DialectRisingwave) createTable(table *schema.Table) error {
var sb strings.Builder
addedColumns := make(map[string]struct{})

tableName := d.FullTableName(table)

sb.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", tableName))

// Add primary key if it exists
var primaryKeyFieldName string
if table.PrimaryKey != nil {
pk := table.PrimaryKey
primaryKeyFieldName = pk.Name
d.AddPrimaryKeySql(table.Name, fmt.Sprintf("alter table %s add constraint %s_pk primary key (%s);", tableName, table.Name, primaryKeyFieldName))
sb.WriteString(fmt.Sprintf("%s %s,", pk.Name, MapFieldType(pk.FieldDescriptor)))
sb.WriteString(fmt.Sprintf("%s %s PRIMARY KEY,", pk.Name, MapFieldType(pk.FieldDescriptor)))
addedColumns[pk.Name] = struct{}{}
}

// Always add block metadata columns
sb.WriteString(" block_number INTEGER NOT NULL,")
sb.WriteString(" block_timestamp TIMESTAMP WITH TIME ZONE NOT NULL,")
addedColumns["block_number"] = struct{}{}
addedColumns["block_timestamp"] = struct{}{}

// Add parent key for child tables
var parentKeyColumns []string
if table.ChildOf != nil {
parentTable, parentFound := d.TableRegistry[table.ChildOf.ParentTable]
if !parentFound {
return fmt.Errorf("parent table %q not found", table.ChildOf.ParentTable)
}
fieldFound := false
for _, parentField := range parentTable.Columns {

if parentField.Name == table.ChildOf.ParentTableField {

sb.WriteString(fmt.Sprintf("%s %s NOT NULL,", parentField.Name, MapFieldType(parentField.FieldDescriptor)))

foreignKey := &sql2.ForeignKey{
Name: "fk_" + table.ChildOf.ParentTable,
Table: tableName,
Field: table.ChildOf.ParentTableField,
ForeignTable: d.FullTableName(parentTable),
ForeignField: parentField.Name,
if _, exists := addedColumns[parentField.Name]; !exists {
sb.WriteString(fmt.Sprintf("%s %s NOT NULL,", parentField.Name, MapFieldType(parentField.FieldDescriptor)))
addedColumns[parentField.Name] = struct{}{}
parentKeyColumns = append(parentKeyColumns, parentField.Name)
}

d.AddForeignKeySql(table.Name, foreignKey.String())

fieldFound = true
break
}
Expand All @@ -120,34 +117,32 @@ func (d *DialectRisingwave) createTable(table *schema.Table) error {
}
}

// Add all regular columns from the protobuf message
for _, f := range table.Columns {
// Skip if already added
if _, exists := addedColumns[f.Name]; exists {
continue
}

// Skip primary key (already handled above)
if f.Name == primaryKeyFieldName {
continue
}

fieldQuotedName := f.QuotedName()

switch {
case f.IsRepeated:
// Skip repeated fields (not supported in SQL)
if f.IsRepeated {
continue
case f.IsMessage && !IsWellKnownType(f.FieldDescriptor):
childTable, found := d.TableRegistry[f.Message]
if !found {
continue
}
if childTable.PrimaryKey == nil {
continue
}
foreignKey := &sql2.ForeignKey{
Name: "fk_" + childTable.Name,
Table: tableName,
Field: fieldQuotedName,
ForeignTable: d.FullTableName(childTable),
ForeignField: childTable.PrimaryKey.Name,
}
d.AddForeignKeySql(table.Name, foreignKey.String())
}

case f.ForeignKey != nil:
// Skip message fields that don't map to simple columns
if f.IsMessage && !IsWellKnownType(f.FieldDescriptor) {
continue
}

// Handle foreign key fields (but don't add constraints since RisingWave doesn't support them)
if f.ForeignKey != nil {
foreignTable, found := d.TableRegistry[f.ForeignKey.Table]
if !found {
return fmt.Errorf("foreign table %q not found", f.ForeignKey.Table)
Expand All @@ -163,38 +158,48 @@ func (d *DialectRisingwave) createTable(table *schema.Table) error {
if foreignField == nil {
return fmt.Errorf("foreign field %q not found in table %q", f.ForeignKey.TableField, f.ForeignKey.Table)
}

foreignKey := &sql2.ForeignKey{
Name: "fk_" + f.Name,
Table: tableName,
Field: f.Name,
ForeignTable: d.FullTableName(foreignTable),
ForeignField: foreignField.Name,
}
d.AddForeignKeySql(table.Name, foreignKey.String())
}

// Determine field type
fieldType := MapFieldType(f.FieldDescriptor)
if f.IsUnique {
d.AddUniqueConstraintSql(table.Name, fmt.Sprintf("alter table %s add constraint %s_%s_unique unique (%s);", tableName, table.Name, f.Name, fieldQuotedName))
fieldType = fieldType + " UNIQUE"
}

// Add the column
sb.WriteString(fmt.Sprintf("%s %s", fieldQuotedName, fieldType))
sb.WriteString(",")
addedColumns[f.Name] = struct{}{}
}

//removing the last comma since it is complicated to removing it before
temp := sb.String()
temp = temp[:len(temp)-1]
sb = strings.Builder{}
sb.WriteString(temp)
// Add composite primary key if no explicit primary key exists
if table.PrimaryKey == nil {
// Remove the last comma before adding primary key constraint
temp := sb.String()
temp = temp[:len(temp)-1]
sb = strings.Builder{}
sb.WriteString(temp)

// Build composite primary key: always include block_number, then parent keys if any
var pkColumns []string
pkColumns = append(pkColumns, "block_number")
pkColumns = append(pkColumns, parentKeyColumns...)

// Create the primary key constraint
sb.WriteString(fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(pkColumns, ", ")))
} else {
// Remove the last comma for tables with explicit primary key
temp := sb.String()
temp = temp[:len(temp)-1]
sb = strings.Builder{}
sb.WriteString(temp)
}

sb.WriteString(");\n")
sb.WriteString("\n);\n")

d.AddForeignKeySql(tableName, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT fk_block FOREIGN KEY (block_number) REFERENCES %s._blocks_(number) ON DELETE CASCADE", tableName, d.schemaName))
d.AddCreateTableSql(table.Name, sb.String())

return nil

}

func (d *DialectRisingwave) CreateDatabase(tx *sql.Tx) error {
Expand All @@ -214,34 +219,6 @@ func (d *DialectRisingwave) CreateDatabase(tx *sql.Tx) error {
return nil
}

// todo: move to postgres database ...
func (d *DialectRisingwave) ApplyConstraints(tx *sql.Tx) error {
startAt := time.Now()
for _, constraint := range d.PrimaryKeySql {
d.Logger.Info("executing pk statement", zap.String("sql", constraint.Sql))
_, err := tx.Exec(constraint.Sql)
if err != nil {
return fmt.Errorf("executing pk statement: %w %s", err, constraint.Sql)
}
}
for _, constraint := range d.UniqueConstraintSql {
d.Logger.Info("executing unique statement", zap.String("sql", constraint.Sql))
_, err := tx.Exec(constraint.Sql)
if err != nil {
return fmt.Errorf("executing unique statement: %w %s", err, constraint.Sql)
}
}
for _, constraint := range d.ForeignKeySql {
d.Logger.Info("executing fk constraint statement", zap.String("sql", constraint.Sql))
_, err := tx.Exec(constraint.Sql)
if err != nil {
return fmt.Errorf("executing fk constraint statement: %w %s", err, constraint.Sql)
}
}
d.Logger.Info("applying constraints", zap.Duration("duration", time.Since(startAt)))
return nil
}

func (d *DialectRisingwave) FullTableName(table *schema.Table) string {
return tableName(d.schemaName, table.Name)
}
Expand Down
Loading