From d3adfb6bf80f230e72cbdadf8a1d3a805e164f98 Mon Sep 17 00:00:00 2001 From: Racso-3141 Date: Sat, 7 Oct 2023 18:21:15 -0400 Subject: [PATCH] Refactor NewStandardSQLTable to input columns and add test cases for sql_table --- .../datasource/sql_datasource/generic.go | 6 +- .../sql_datasource/sql_datasource.go | 4 +- .../stackql/datasource/sql_table/sql_table.go | 31 ----- .../stackql/datasource/sqltable/sqltable.go | 34 +++++ .../datasource/sqltable/sqltable_test.go | 130 ++++++++++++++++++ 5 files changed, 169 insertions(+), 36 deletions(-) delete mode 100644 internal/stackql/datasource/sql_table/sql_table.go create mode 100644 internal/stackql/datasource/sqltable/sqltable.go create mode 100644 internal/stackql/datasource/sqltable/sqltable_test.go diff --git a/internal/stackql/datasource/sql_datasource/generic.go b/internal/stackql/datasource/sql_datasource/generic.go index db02bcd0..85e46b4a 100644 --- a/internal/stackql/datasource/sql_datasource/generic.go +++ b/internal/stackql/datasource/sql_datasource/generic.go @@ -8,7 +8,7 @@ import ( _ "github.com/snowflakedb/gosnowflake" //nolint:revive,nolintlint // this is a DB driver pattern "github.com/stackql/stackql/internal/stackql/constants" - "github.com/stackql/stackql/internal/stackql/datasource/sql_table" + "github.com/stackql/stackql/internal/stackql/datasource/sqltable" "github.com/stackql/stackql/internal/stackql/db_util" "github.com/stackql/stackql/internal/stackql/dto" ) @@ -67,11 +67,11 @@ func (ds *genericSQLDataSource) Begin() (*sql.Tx, error) { return ds.db.Begin() } -func (ds *genericSQLDataSource) GetTableMetadata(args ...string) (sql_table.SQLTable, error) { +func (ds *genericSQLDataSource) GetTableMetadata(args ...string) (sqltable.SQLTable, error) { return nil, fmt.Errorf("could not obtain sql data source table metadata for args = '%v'", args) } -// func (ds *genericSQLDataSource) getPostgresTableMetadata(schemaName, tableName string) (sql_table.SQLTable, error) { +// func (ds *genericSQLDataSource) getPostgresTableMetadata(schemaName, tableName string) (sqltable.SQLTable, error) { // queryTemplate := ` // SELECT // column_name, diff --git a/internal/stackql/datasource/sql_datasource/sql_datasource.go b/internal/stackql/datasource/sql_datasource/sql_datasource.go index b17fdc01..c9cdd340 100644 --- a/internal/stackql/datasource/sql_datasource/sql_datasource.go +++ b/internal/stackql/datasource/sql_datasource/sql_datasource.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/stackql/stackql/internal/stackql/constants" - "github.com/stackql/stackql/internal/stackql/datasource/sql_table" + "github.com/stackql/stackql/internal/stackql/datasource/sqltable" "github.com/stackql/stackql/internal/stackql/dto" ) @@ -14,7 +14,7 @@ type SQLDataSource interface { Exec(string, ...interface{}) (sql.Result, error) Query(string, ...interface{}) (*sql.Rows, error) QueryRow(string, ...any) *sql.Row - GetTableMetadata(...string) (sql_table.SQLTable, error) + GetTableMetadata(...string) (sqltable.SQLTable, error) GetSchemaType() string GetDBName() string } diff --git a/internal/stackql/datasource/sql_table/sql_table.go b/internal/stackql/datasource/sql_table/sql_table.go deleted file mode 100644 index e33b9e0b..00000000 --- a/internal/stackql/datasource/sql_table/sql_table.go +++ /dev/null @@ -1,31 +0,0 @@ -package sql_table //nolint:revive,stylecheck // decent package name - -import ( - "github.com/stackql/stackql/internal/stackql/symtab" - "github.com/stackql/stackql/internal/stackql/typing" -) - -type SQLTable interface { - GetColumns() []typing.RelationalColumn - GetSymTab() symtab.SymTab -} - -type standardSQLTable struct { - symTab symtab.SymTab - colz []typing.RelationalColumn -} - -func NewStandardSQLTable(_ []typing.RelationalColumn) (SQLTable, error) { - rv := &standardSQLTable{ - symTab: symtab.NewHashMapTreeSymTab(), - } - return rv, nil -} - -func (sqt *standardSQLTable) GetSymTab() symtab.SymTab { - return sqt.symTab -} - -func (sqt *standardSQLTable) GetColumns() []typing.RelationalColumn { - return sqt.colz -} diff --git a/internal/stackql/datasource/sqltable/sqltable.go b/internal/stackql/datasource/sqltable/sqltable.go new file mode 100644 index 00000000..219275bb --- /dev/null +++ b/internal/stackql/datasource/sqltable/sqltable.go @@ -0,0 +1,34 @@ +package sqltable + +import ( + "github.com/stackql/stackql/internal/stackql/symtab" + "github.com/stackql/stackql/internal/stackql/typing" +) + +type SQLTable interface { + GetColumns() []typing.RelationalColumn + GetSymTab() symtab.SymTab +} + +type StandardSQLTable struct { + symTab symtab.SymTab + columns []typing.RelationalColumn +} + +func NewStandardSQLTable(relationalColumns []typing.RelationalColumn) (SQLTable, error) { + copiedSlice := make([]typing.RelationalColumn, len(relationalColumns)) + copy(copiedSlice, relationalColumns) + rv := &StandardSQLTable{ + symTab: symtab.NewHashMapTreeSymTab(), + columns: copiedSlice, + } + return rv, nil +} + +func (sqt *StandardSQLTable) GetSymTab() symtab.SymTab { + return sqt.symTab +} + +func (sqt *StandardSQLTable) GetColumns() []typing.RelationalColumn { + return sqt.columns +} diff --git a/internal/stackql/datasource/sqltable/sqltable_test.go b/internal/stackql/datasource/sqltable/sqltable_test.go new file mode 100644 index 00000000..ddf2f5d7 --- /dev/null +++ b/internal/stackql/datasource/sqltable/sqltable_test.go @@ -0,0 +1,130 @@ +package sqltable_test + +import ( + "math/rand" + "reflect" + "testing" + "time" + + "github.com/stackql/stackql/internal/stackql/datasource/sqltable" + "github.com/stackql/stackql/internal/stackql/symtab" + "github.com/stackql/stackql/internal/stackql/typing" +) + +func randString() string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + maxLength := 256 + stringLength := r.Intn(maxLength + 1) // Randomly decide the length of the string + s := make([]byte, stringLength) + for i := range s { + s[i] = letters[r.Intn(len(letters))] + } + return string(s) +} + +func generateRandomColumns(n int) []typing.RelationalColumn { + columns := make([]typing.RelationalColumn, n) + for i := range columns { + // Assuming RelationalColumn is a type like string for simplicity + columns[i] = typing.NewRelationalColumn(randString(), randString()) // Generate a random string of length 10 + } + return columns +} + +func TestNewStandardSQLTable(t *testing.T) { + table, err := sqltable.NewStandardSQLTable(generateRandomColumns(10)) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + + if table == nil { + t.Fatal("Expected table to be non-nil") + } + + _, ok := table.(*sqltable.StandardSQLTable) + if !ok { + t.Fatal("Expected table to be of type *standardSQLTable") + } +} + +func TestGetSymTab(t *testing.T) { + columns := generateRandomColumns(10) + table, _ := sqltable.NewStandardSQLTable(columns) + + // Initialize symTab with different values + symTab := table.GetSymTab() + + // Set symbols in the symTab + err := symTab.SetSymbol("testKey", symtab.NewSymTabEntry("testType", "testData", "testIn")) + if err != nil { + t.Fatalf("Failed to set symbol: %v", err) + } + + // Test if the symbol was set correctly + entry, exists := symTab.GetSymbol("testKey") + if exists != nil { + t.Fatalf("Symbol not found in symTab") + } + if !reflect.DeepEqual(entry, symtab.NewSymTabEntry("testType", "testData", "testIn")) { + t.Fatalf("Symbol not set correctly in symTab") + } + + // Create a new leaf and set symbols in it + leafSymTab, err := symTab.NewLeaf("testLeafKey") + if err != nil { + t.Fatalf("Failed to create new leaf: %v", err) + } + err = leafSymTab.SetSymbol("leafKey", symtab.NewSymTabEntry("leafType", "leafData", "leafIn")) + if err != nil { + t.Fatalf("Failed to set symbol in leaf: %v", err) + } + + // Test if the symbol was set correctly in the leaf + entry, exists = leafSymTab.GetSymbol("leafKey") + if exists != nil { + t.Fatalf("Symbol not found in leafSymTab") + } + if !reflect.DeepEqual(entry, symtab.NewSymTabEntry("leafType", "leafData", "leafIn")) { + t.Fatalf("Symbol not set correctly in leafSymTab") + } +} + +func TestGetColumns(t *testing.T) { + testCases := []struct { + name string + numColumns int + }{ + { + name: "Test with 0 columns", + numColumns: 0, + }, + { + name: "Test with 5 columns", + numColumns: 5, + }, + { + name: "Test with 10 columns", + numColumns: 10, + }, + { + name: "Test with 15 columns", + numColumns: 15, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + inputColumns := generateRandomColumns(tc.numColumns) + table, err := sqltable.NewStandardSQLTable(inputColumns) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + + returnedColumns := table.GetColumns() + if !reflect.DeepEqual(returnedColumns, inputColumns) { + t.Fatalf("Expected columns %v, but got %v", inputColumns, returnedColumns) + } + }) + } +}