Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions internal/stackql/mcpbackend/mcp_reverse_proxy_backend_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,19 @@ func (b *stackqlMCPReverseProxyService) ListMethods(ctx context.Context, hI dto.
}
return b.query(ctx, q, hI.RowLimit)
}

func (b *stackqlMCPReverseProxyService) ListRegistry(ctx context.Context, input dto.RegistryInput) ([]map[string]interface{}, error) {
q, qErr := b.interrogator.GetRegistryList(input.Provider)
if qErr != nil {
return nil, qErr
}
return b.query(ctx, q, unlimitedRowLimit)
}

func (b *stackqlMCPReverseProxyService) PullProvider(ctx context.Context, input dto.RegistryInput) (map[string]any, error) {
q, qErr := b.interrogator.GetRegistryPull(input.Provider, input.Version)
if qErr != nil {
return nil, qErr
}
return b.ExecQuery(ctx, q)
}
120 changes: 102 additions & 18 deletions internal/stackql/mcpbackend/mcp_service_stackql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/sirupsen/logrus"
"github.com/stackql/psql-wire/pkg/sqldata"
"github.com/stackql/stackql/internal/stackql/acid/tsm_physio"
"github.com/stackql/stackql/internal/stackql/buildinfo"
"github.com/stackql/stackql/internal/stackql/handler"
Expand All @@ -21,6 +22,11 @@ var (

const (
unlimitedRowLimit int = -1
// forbiddenRegistryCharacters mirrors the CLI registry command's guard
// (see internal/stackql/cmd/registry.go). Interrogator methods that
// interpolate user-supplied registry identifiers reject these characters
// rather than substituting / escaping them, matching CLI semantics.
forbiddenRegistryCharacters string = ` ;\`
)

// serverBuildInfo carries the runtime + build-time metadata reported by the
Expand Down Expand Up @@ -90,6 +96,8 @@ type StackqlInterrogator interface {
GetDescribeResource(dto.HierarchyInput) (string, error)
GetDescribeMethod(dto.HierarchyInput) (string, error)
GetQueryJSON(dto.QueryJSONInput) (string, error)
GetRegistryList(provider string) (string, error)
GetRegistryPull(provider, version string) (string, error)
}

type simpleStackqlInterrogator struct{}
Expand Down Expand Up @@ -192,6 +200,39 @@ func (s *simpleStackqlInterrogator) GetQueryJSON(qI dto.QueryJSONInput) (string,
return qI.SQL, nil
}

func (s *simpleStackqlInterrogator) GetRegistryList(provider string) (string, error) {
if provider != "" && strings.ContainsAny(provider, forbiddenRegistryCharacters) {
return "", fmt.Errorf("forbidden characters in provider")
}
sb := strings.Builder{}
sb.WriteString("REGISTRY LIST")
if provider != "" {
sb.WriteString(" ")
sb.WriteString(provider)
}
sb.WriteString(";")
return sb.String(), nil
}

func (s *simpleStackqlInterrogator) GetRegistryPull(provider, version string) (string, error) {
if provider == "" {
return "", fmt.Errorf("provider not specified")
}
if strings.ContainsAny(provider, forbiddenRegistryCharacters) ||
strings.ContainsAny(version, forbiddenRegistryCharacters) {
return "", fmt.Errorf("forbidden characters in provider or version")
}
sb := strings.Builder{}
sb.WriteString("REGISTRY PULL ")
sb.WriteString(provider)
if version != "" {
sb.WriteString(" ")
sb.WriteString(version)
}
sb.WriteString(";")
return sb.String(), nil
}

type stackqlMCPService struct {
txnOrchestrator tsm_physio.Orchestrator
interrogator StackqlInterrogator
Expand Down Expand Up @@ -317,34 +358,61 @@ func (b *stackqlMCPService) applyQuery(query string) ([]internaldto.ExecutorOutp

func (b *stackqlMCPService) extractQueryResults(query string, rowLimit int) ([]map[string]interface{}, bool) {
r, ok := b.applyQuery(query)
var rv []map[string]interface{}
// Initialise as empty (not nil) so a zero-row result survives downstream
// JSON-array schema validation on QueryResultDTO.Rows. This pairs with
// fix 1 (returning ok regardless of len(rv)) so empty results render as
// "**no results**" rather than failing extraction.
rv := []map[string]interface{}{}
rowCount := 0
for _, resp := range r {
if respErr := resp.GetError(); respErr != nil {
ok = false
break
}
// PrepareResultSet emits a nil SQLResult when RowMap is empty (eg
// REGISTRY LIST against an empty registry). That's a zero-row
// result, not an extraction failure: skip the stream and continue.
sqlRowStream := resp.GetSQLResult()
if sqlRowStream == nil {
continue
}
var drainOK bool
rv, rowCount, drainOK = drainSQLRowStream(sqlRowStream, rv, rowCount, rowLimit)
if !drainOK {
ok = false
break
}
for {
row, err := sqlRowStream.Read()
if err == io.EOF {
rowArr := row.ToArr()
rv = append(rv, rowArr...)
break
}
if err != nil || row == nil {
ok = false
break
}
rowArr := row.ToArr()
rv = append(rv, rowArr...)
rowCount += len(rowArr)
if rowLimit > 0 && rowCount >= rowLimit {
break
}
return rv, ok
}

// drainSQLRowStream reads `stream` to EOF (or until rowLimit is reached),
// appending each row's payload to `rv`. The returned bool is false when the
// stream surfaces a read error or a nil row outside of EOF; that maps onto
// extractQueryResults' (rv, false) failure mode.
func drainSQLRowStream(
stream sqldata.ISQLResultStream,
rv []map[string]interface{},
rowCount, rowLimit int,
) ([]map[string]interface{}, int, bool) {
for {
row, err := stream.Read()
if err == io.EOF {
if row != nil {
rv = append(rv, row.ToArr()...)
}
return rv, rowCount, true
}
if err != nil || row == nil {
return rv, rowCount, false
}
rowArr := row.ToArr()
rv = append(rv, rowArr...)
rowCount += len(rowArr)
if rowLimit > 0 && rowCount >= rowLimit {
return rv, rowCount, true
}
}
return rv, (ok && len(rv) > 0)
}

func (b *stackqlMCPService) DescribeResource(ctx context.Context, hI dto.HierarchyInput) ([]map[string]interface{}, error) {
Expand Down Expand Up @@ -394,3 +462,19 @@ func (b *stackqlMCPService) ListMethods(ctx context.Context, hI dto.HierarchyInp
}
return b.runPreprocessedQueryJSON(ctx, q, unlimitedRowLimit)
}

func (b *stackqlMCPService) ListRegistry(ctx context.Context, input dto.RegistryInput) ([]map[string]interface{}, error) {
q, qErr := b.interrogator.GetRegistryList(input.Provider)
if qErr != nil {
return nil, qErr
}
return b.runPreprocessedQueryJSON(ctx, q, unlimitedRowLimit)
}

func (b *stackqlMCPService) PullProvider(ctx context.Context, input dto.RegistryInput) (map[string]any, error) {
q, qErr := b.interrogator.GetRegistryPull(input.Provider, input.Version)
if qErr != nil {
return nil, qErr
}
return b.ExecQuery(ctx, q)
}
11 changes: 11 additions & 0 deletions pkg/mcp_server/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ type Backend interface {

// DescribeMethod returns the full I/O contract for one method.
DescribeMethod(ctx context.Context, hI dto.HierarchyInput) ([]map[string]any, error)

// ListRegistry lists providers (and their versions) available in the registry.
// When input.Provider is empty, lists all available providers; otherwise lists
// versions for that provider. Distinct from ListProviders, which lists only
// providers already pulled into the local cache.
ListRegistry(ctx context.Context, input dto.RegistryInput) ([]map[string]any, error)

// PullProvider installs a provider from the registry into the local approot
// cache. input.Provider is required; input.Version is optional (empty pulls
// the latest published version). Returns the same shape as ExecQuery.
PullProvider(ctx context.Context, input dto.RegistryInput) (map[string]any, error)
}

// QueryResult represents the result of a query execution.
Expand Down
9 changes: 9 additions & 0 deletions pkg/mcp_server/dto/dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ type QueryJSONInput struct {
RowLimit int `json:"row_limit,omitempty" yaml:"row_limit,omitempty"`
}

// RegistryInput is the shared input shape for list_registry and pull_provider.
// list_registry treats Provider as optional (when empty, lists all available
// providers); pull_provider requires Provider and treats Version as optional
// (when empty, the latest published version is pulled).
type RegistryInput struct {
Provider string `json:"provider,omitempty" yaml:"provider,omitempty"`
Version string `json:"version,omitempty" yaml:"version,omitempty"`
}

// QueryResultDTO is the typed structured payload returned alongside the rendered text.
type QueryResultDTO struct {
Rows []map[string]any `json:"rows"`
Expand Down
8 changes: 8 additions & 0 deletions pkg/mcp_server/example_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ func (b *ExampleBackend) ListResources(ctx context.Context, hI dto.HierarchyInpu
return []map[string]any{}, nil
}

func (b *ExampleBackend) ListRegistry(ctx context.Context, input dto.RegistryInput) ([]map[string]any, error) {
return []map[string]any{}, nil
}

func (b *ExampleBackend) PullProvider(ctx context.Context, input dto.RegistryInput) (map[string]any, error) {
return map[string]any{}, nil
}

// NewExampleBackend creates a new example backend instance.
func NewExampleBackend(connectionString string) Backend {
return &ExampleBackend{
Expand Down
16 changes: 16 additions & 0 deletions pkg/mcp_server/gate.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ func extractArgsFromHierarchy(args any) map[string]any {
return hierarchyToMap(v)
}

// extractArgsFromRegistryInput returns {provider, version} for audit recording.
func extractArgsFromRegistryInput(args any) map[string]any {
v, ok := args.(dto.RegistryInput)
if !ok {
return nil
}
out := map[string]any{}
if v.Provider != "" {
out["provider"] = v.Provider
}
if v.Version != "" {
out["version"] = v.Version
}
return out
}

func hierarchyToMap(v dto.HierarchyInput) map[string]any {
out := map[string]any{}
if v.Provider != "" {
Expand Down
45 changes: 43 additions & 2 deletions pkg/mcp_server/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,53 @@ package render

import (
"fmt"
"reflect"
"sort"
"strings"
)

const noResults = "**no results**"

// unwrap normalises database/sql nullable wrappers (sql.NullString, NullBool,
// NullInt64, NullInt32, NullFloat64, NullByte, NullTime, the generic sql.Null[T])
// down to their scalar payload before formatting. Anything else is returned
// unchanged. Invalid wrappers collapse to "" so cells render empty rather than
// as the typed zero value.
func unwrap(v any) any {
if v == nil {
return nil
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr {
if rv.IsNil() {
return nil
}
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
return v
}
validField := rv.FieldByName("Valid")
if !validField.IsValid() || validField.Kind() != reflect.Bool {
return v
}
if !validField.Bool() {
return ""
}
return firstNonValidField(rv)
}

// firstNonValidField returns the first exported struct field whose name is not
// "Valid". Split out of unwrap to keep gocognit complexity low.
func firstNonValidField(rv reflect.Value) any {
for i := 0; i < rv.NumField(); i++ {
if rv.Type().Field(i).Name != "Valid" {
return rv.Field(i).Interface()
}
}
return rv.Interface()
}

// RenderTable renders a uniform multi-row result set as a markdown table.
// Column order is stable: the union of keys across all rows, sorted alphabetically.
func RenderTable(rows []map[string]any) string {
Expand Down Expand Up @@ -44,7 +85,7 @@ func RenderKV(title string, records []map[string]any) string {
sb.WriteString(fmt.Sprintf("## Record %d\n\n", i+1))
keys := sortedKeys(rec)
for _, k := range keys {
sb.WriteString(fmt.Sprintf("%s: %v\n", k, rec[k]))
sb.WriteString(fmt.Sprintf("%s: %v\n", k, unwrap(rec[k])))
}
if i < len(records)-1 {
sb.WriteString("\n")
Expand Down Expand Up @@ -103,7 +144,7 @@ func dataRow(columns []string, row map[string]any) string {
sb.WriteString("| ")
continue
}
sb.WriteString(fmt.Sprintf("| %v ", v))
sb.WriteString(fmt.Sprintf("| %v ", unwrap(v)))
}
sb.WriteString("|")
return sb.String()
Expand Down
50 changes: 50 additions & 0 deletions pkg/mcp_server/render/render_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package render_test

import (
"database/sql"
"strings"
"testing"

Expand Down Expand Up @@ -48,3 +49,52 @@ func TestRenderKV_Empty(t *testing.T) {
t.Fatalf("expected 'no results' message: %q", got)
}
}

// Issue #661 fix 2: nullable wrappers (and pointers to them) must render as
// scalars, not as Go default-format struct text like "&{ok true}".
func TestRenderTable_UnwrapsNullableWrappers(t *testing.T) {
rows := []map[string]any{{
"s": &sql.NullString{String: "ok", Valid: true},
"b": &sql.NullBool{Bool: true, Valid: true},
}}
got := render.RenderTable(rows)
if strings.Contains(got, "&{") {
t.Errorf("table should not contain Go wrapper text: %q", got)
}
if !strings.Contains(got, "| ok |") {
t.Errorf("expected unwrapped string value, got %q", got)
}
if !strings.Contains(got, "| true |") {
t.Errorf("expected unwrapped bool value, got %q", got)
}
}

func TestRenderKV_UnwrapsNullableWrappers(t *testing.T) {
rec := []map[string]any{{
"s": sql.NullString{String: "ok", Valid: true},
"b": &sql.NullBool{Bool: false, Valid: true},
}}
got := render.RenderKV("Sample", rec)
if strings.Contains(got, "&{") || strings.Contains(got, "{ok") {
t.Errorf("kv should not contain Go wrapper text: %q", got)
}
if !strings.Contains(got, "s: ok") {
t.Errorf("expected unwrapped string line, got %q", got)
}
if !strings.Contains(got, "b: false") {
t.Errorf("expected unwrapped bool line, got %q", got)
}
}

func TestRender_InvalidNullableRendersAsEmpty(t *testing.T) {
rows := []map[string]any{{
"s": sql.NullString{String: "ignored", Valid: false},
}}
got := render.RenderTable(rows)
if strings.Contains(got, "ignored") {
t.Errorf("invalid Nullable should not surface payload, got %q", got)
}
if !strings.Contains(got, "| |") {
t.Errorf("expected empty cell for invalid Nullable, got %q", got)
}
}
Loading
Loading