diff --git a/.gitignore b/.gitignore index ed0afa4ad..2bebf504f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,6 @@ tmp/ *.gz .DS_Store autogen -dql -dql *.iml *.so logs diff --git a/Version b/Version index 014ec6192..fcc9d59a4 100644 --- a/Version +++ b/Version @@ -1 +1 @@ -v0.20.2 \ No newline at end of file +v0.21.0 \ No newline at end of file diff --git a/cmd/cli.go b/cmd/cli.go index 166f1d8f7..c2455c126 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -3,13 +3,13 @@ package cmd import ( "context" "fmt" + "github.com/jessevdk/go-flags" "github.com/viant/datly/cmd/command" soptions "github.com/viant/datly/cmd/options" ) func RunApp(version string, args soptions.Arguments) error { - options, err := buildOptions(args) if err != nil { return err diff --git a/cmd/command/async_defaults.go b/cmd/command/async_defaults.go new file mode 100644 index 000000000..7ffbd4c43 --- /dev/null +++ b/cmd/command/async_defaults.go @@ -0,0 +1,33 @@ +package command + +import ( + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/gateway/runtime/standalone" +) + +const ( + defaultJobURL = "/tmp/datly/jobs" + defaultFailedJobURL = "/tmp/datly/failed" +) + +func applyAsyncJobDefaults(config *standalone.Config) { + if config == nil { + return + } + + if config.JobURL == "" && config.FailedJobURL == "" { + config.JobURL = defaultJobURL + config.FailedJobURL = defaultFailedJobURL + return + } + + if config.JobURL == "" { + config.JobURL = defaultJobURL + } + + if config.FailedJobURL == "" { + parent, _ := url.Split(config.JobURL, file.Scheme) + config.FailedJobURL = url.Join(parent, "failed", "jobs") + } +} diff --git a/cmd/command/async_defaults_test.go b/cmd/command/async_defaults_test.go new file mode 100644 index 000000000..736e45492 --- /dev/null +++ b/cmd/command/async_defaults_test.go @@ -0,0 +1,51 @@ +package command + +import ( + "testing" + + "github.com/viant/datly/gateway" + "github.com/viant/datly/gateway/runtime/standalone" +) + +func 
TestApplyAsyncJobDefaults(t *testing.T) { + testCases := []struct { + name string + jobURL string + failedJobURL string + expectJob string + expectFailed string + }{ + { + name: "both empty use tmp defaults", + expectJob: "/tmp/datly/jobs", + expectFailed: "/tmp/datly/failed", + }, + { + name: "custom job only derives failed path", + jobURL: "/custom/jobs", + expectJob: "/custom/jobs", + expectFailed: "file://localhost/custom/failed/jobs", + }, + { + name: "failed only keeps failed and defaults job", + failedJobURL: "/custom/failed", + expectJob: "/tmp/datly/jobs", + expectFailed: "/custom/failed", + }, + } + + for _, testCase := range testCases { + cfg := &standalone.Config{Config: &gateway.Config{}} + cfg.JobURL = testCase.jobURL + cfg.FailedJobURL = testCase.failedJobURL + + applyAsyncJobDefaults(cfg) + + if cfg.JobURL != testCase.expectJob { + t.Fatalf("%s: expected JobURL=%s, got %s", testCase.name, testCase.expectJob, cfg.JobURL) + } + if cfg.FailedJobURL != testCase.expectFailed { + t.Fatalf("%s: expected FailedJobURL=%s, got %s", testCase.name, testCase.expectFailed, cfg.FailedJobURL) + } + } +} diff --git a/cmd/command/generate.go b/cmd/command/generate.go index e82c89245..c516f91d5 100644 --- a/cmd/command/generate.go +++ b/cmd/command/generate.go @@ -42,6 +42,9 @@ func (s *Service) generate(ctx context.Context, options *options.Options) error if _, err := s.loadPlugin(ctx, options); err != nil { return err } + if ruleOption.EffectiveEngine() == "shape" && options.Generate.Operation != "get" { + return fmt.Errorf("shape engine currently supports gen get only") + } if options.Generate.Operation == "get" { return s.generateGet(ctx, options) } @@ -144,8 +147,51 @@ func (s *Service) generateGet(ctx context.Context, opts *options.Options) (err e if err = s.translate(ctx, opts); err != nil { return err } - if err = s.persistRepository(ctx); err != nil { - return err + if opts.Rule().EffectiveEngine() != options.EngineShape { + if err = s.persistRepository(ctx); 
err != nil { + return err + } + } + + if opts.Rule().EffectiveEngine() == options.EngineShape { + componentURL := url.Join(translate.Repository.RepositoryURL, "Datly", "routes") + datlySrv, err := datly.New(ctx, repository.WithComponentURL(componentURL)) + if err != nil { + return err + } + for i, source := range sources { + translate.Rule.Index = i + sourceText, loadErr := translate.Rule.LoadSource(ctx, s.fs, source) + if loadErr != nil { + return loadErr + } + method, uri := parseShapeRulePath(sourceText, translate.Rule.RuleName(), translate.Repository.APIPrefix) + key := uri + if !strings.EqualFold(method, "GET") { + key = method + ":" + uri + } + aComponent, compErr := datlySrv.Component(ctx, key) + if compErr != nil { + return compErr + } + applyDefaultComponentPackage(aComponent, translate.Rule.ModulePrefix) + _, sourceName := path.Split(url.Path(source)) + sourceName = trimExt(sourceName) + var embeds = map[string]string{} + var namedResources []string + if repo := opts.Repository(); repo != nil && len(repo.SubstitutesURL) > 0 { + namedResources = append(namedResources, repo.SubstitutesURL...) + } + code := aComponent.GenerateOutputCode(ctx, defComp, true, embeds, namedResources...) 
+ destURL := path.Join(translate.Rule.ModuleLocation, translate.Rule.ModulePrefix, sourceName+".go") + if err = s.fs.Upload(ctx, destURL, file.DefaultFileOsMode, strings.NewReader(code)); err != nil { + return err + } + if err = s.persistEmbeds(ctx, translate.Rule.ModuleLocation, translate.Rule.ModulePrefix, embeds, aComponent); err != nil { + return err + } + } + return nil } for i, resource := range s.translator.Repository.Resource { @@ -167,6 +213,7 @@ func (s *Service) generateGet(ctx context.Context, opts *options.Options) (err e if err != nil { return err } + applyDefaultComponentPackage(aComponent, modulePrefix) var embeds = map[string]string{} var namedResources []string @@ -190,6 +237,24 @@ func (s *Service) generateGet(ctx context.Context, opts *options.Options) (err e return nil } +func applyDefaultComponentPackage(component *repository.Component, modulePrefix string) { + if component == nil { + return + } + if component.Output.Type.Package != "" || component.Input.Type.Package != "" { + return + } + modulePrefix = strings.Trim(modulePrefix, "/") + if modulePrefix == "" { + return + } + base := path.Base(modulePrefix) + if base == "" || base == "." 
|| base == "/" { + return + } + component.Output.Type.Package = strings.ReplaceAll(base, "-", "_") +} + func (s *Service) persistEmbeds(ctx context.Context, moduleLocation string, modulePrefix string, embeds map[string]string, component *repository.Component) error { rootName := component.View.Name formatter := text.DetectCaseFormat(rootName) diff --git a/cmd/command/mcp.go b/cmd/command/mcp.go index d1f0c4989..bbdb84897 100644 --- a/cmd/command/mcp.go +++ b/cmd/command/mcp.go @@ -24,10 +24,7 @@ func (s *Service) mcp(ctx context.Context, mcpOption *options.Mcp) error { setter.SetStringIfEmpty(&s.config.JobURL, mcpOption.JobURL) setter.SetStringIfEmpty(&s.config.FailedJobURL, mcpOption.FailedJobURL) setter.SetIntIfZero(&s.config.MaxJobs, mcpOption.MaxJobs) - if s.config.FailedJobURL == "" && s.config.JobURL != "" { - parent, _ := url.Split(s.config.JobURL, file.Scheme) - s.config.FailedJobURL = url.Join(parent, "failed", "jobs") - } + applyAsyncJobDefaults(s.config) if mcpOption.LoadPlugin && s.config.Config.PluginsURL != "" { parent, _ := url.Split(mcpOption.PluginInfo, file.Scheme) _ = s.fs.Copy(ctx, parent, s.config.Config.PluginsURL) diff --git a/cmd/command/plugin.go b/cmd/command/plugin.go index df9348e39..77b52f10e 100644 --- a/cmd/command/plugin.go +++ b/cmd/command/plugin.go @@ -190,7 +190,7 @@ func (s *Service) reportPluginIssue(ctx context.Context, destURL string) error { if fixBuilder.Len() > 0 { fmt.Printf("[FIXME]: to address pulugin dependency run the following:\n") } - fmt.Printf(fixBuilder.String()) + fmt.Print(fixBuilder.String()) return nil } diff --git a/cmd/command/run.go b/cmd/command/run.go index c55c3e256..a084275d2 100644 --- a/cmd/command/run.go +++ b/cmd/command/run.go @@ -36,10 +36,7 @@ func (s *Service) run(ctx context.Context, run *options.Run) (*standalone.Server setter.SetStringIfEmpty(&s.config.JobURL, run.JobURL) setter.SetStringIfEmpty(&s.config.FailedJobURL, run.FailedJobURL) setter.SetIntIfZero(&s.config.MaxJobs, run.MaxJobs) - 
if s.config.FailedJobURL == "" && s.config.JobURL != "" { - parent, _ := url.Split(s.config.JobURL, file.Scheme) - s.config.FailedJobURL = url.Join(parent, "failed", "jobs") - } + applyAsyncJobDefaults(s.config) if run.LoadPlugin && s.config.Config.PluginsURL != "" { parent, _ := url.Split(run.PluginInfo, file.Scheme) _ = s.fs.Copy(ctx, parent, s.config.Config.PluginsURL) diff --git a/cmd/command/translate.go b/cmd/command/translate.go index 0eea2bbaf..4b4f0921c 100644 --- a/cmd/command/translate.go +++ b/cmd/command/translate.go @@ -29,6 +29,10 @@ func (s *Service) Translate(ctx context.Context, opts *options.Options) (err err if err = s.translate(ctx, opts); err != nil { return err } + engine := opts.Rule().EffectiveEngine() + if engine == options.EngineShape || engine == options.EngineShapeIR { + return nil + } return s.persistRepository(ctx) } @@ -49,6 +53,12 @@ func (s *Service) persistRepository(ctx context.Context) error { } func (s *Service) translate(ctx context.Context, opts *options.Options) error { + switch opts.Rule().EffectiveEngine() { + case options.EngineShape: + return s.translateShape(ctx, opts) + case options.EngineShapeIR: + return s.translateShapeIR(ctx, opts) + } if err := s.ensureTranslator(opts); err != nil { return fmt.Errorf("failed to create translator: %v", err) } diff --git a/cmd/command/translate_shape.go b/cmd/command/translate_shape.go new file mode 100644 index 000000000..fa12d8f03 --- /dev/null +++ b/cmd/command/translate_shape.go @@ -0,0 +1,221 @@ +package command + +import ( + "context" + "encoding/json" + "fmt" + "path" + "path/filepath" + "strings" + + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + 
"github.com/viant/datly/shared" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "gopkg.in/yaml.v3" +) + +func (s *Service) translateShape(ctx context.Context, opts *options.Options) error { + rule := opts.Rule() + compiler := shapeCompile.New() + loader := shapeLoad.New() + for rule.Index = 0; rule.Index < len(rule.Source); rule.Index++ { + sourceURL := rule.SourceURL() + _, name := url.Split(sourceURL, file.Scheme) + fmt.Printf("translating %v (shape)\n", name) + dql, err := rule.LoadSource(ctx, s.fs, sourceURL) + if err != nil { + return err + } + dql = strings.TrimSpace(dql) + if dql == "" { + return fmt.Errorf("source %s was empty", sourceURL) + } + shapeSource := &shape.Source{ + Name: strings.TrimSuffix(name, path.Ext(name)), + Path: url.Path(sourceURL), + DQL: dql, + Connector: strings.TrimSpace(rule.Connector), + } + planResult, err := compiler.Compile(ctx, shapeSource) + if err != nil { + return fmt.Errorf("failed to compile %s: %w", sourceURL, err) + } + componentArtifact, err := loader.LoadComponent(ctx, planResult) + if err != nil { + return fmt.Errorf("failed to load %s: %w", sourceURL, err) + } + component, ok := shapeLoad.ComponentFrom(componentArtifact) + if !ok { + return fmt.Errorf("unexpected component artifact for %s", sourceURL) + } + if err = s.persistShapeRoute(ctx, opts, sourceURL, dql, componentArtifact.Resource, component); err != nil { + return err + } + } + paths := url.Join(opts.Repository().RepositoryURL, "Datly", "routes", "paths.yaml") + if ok, _ := s.fs.Exists(ctx, paths); ok { + _ = s.fs.Delete(ctx, paths) + } + return nil +} + +type shapeRuleFile struct { + Resource *view.Resource `yaml:"Resource,omitempty"` + Routes []*repository.Component `yaml:"Routes,omitempty"` + TypeContext any `yaml:"TypeContext,omitempty"` +} + +func (s *Service) persistShapeRoute(ctx context.Context, opts *options.Options, sourceURL, dql string, resource *view.Resource, component *shapeLoad.Component) error { + rule := opts.Rule() 
+ routeYAML, routeRoot, relDir, stem, err := routePathForShape(rule, opts.Repository().RepositoryURL, sourceURL) + if err != nil { + return err + } + if resource != nil { + for _, item := range resource.Views { + if item == nil || item.Template == nil { + continue + } + if strings.TrimSpace(item.Template.Source) == "" { + continue + } + sqlRel := strings.TrimSpace(item.Template.SourceURL) + if sqlRel == "" { + sqlRel = path.Join(stem, item.Name+".sql") + } + sqlDest := path.Join(routeRoot, relDir, filepath.ToSlash(sqlRel)) + if err = s.fs.Upload(ctx, sqlDest, file.DefaultFileOsMode, strings.NewReader(item.Template.Source)); err != nil { + return fmt.Errorf("failed to persist sql %s: %w", sqlDest, err) + } + item.Template.SourceURL = sqlRel + } + } + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + if rootView == "" && resource != nil && len(resource.Views) > 0 && resource.Views[0] != nil { + rootView = resource.Views[0].Name + } + method, uri := parseShapeRulePath(dql, rule.RuleName(), opts.Repository().APIPrefix) + // Gap 3: RouteDirective overrides method/URI when explicitly declared in DQL. 
+ if component != nil && component.Directives != nil && component.Directives.Route != nil { + rd := component.Directives.Route + if u := strings.TrimSpace(rd.URI); u != "" { + uri = u + } + if len(rd.Methods) > 0 { + if m := strings.TrimSpace(strings.ToUpper(rd.Methods[0])); m != "" { + method = m + } + } + } + route := &repository.Component{ + Path: contract.Path{ + Method: method, + URI: uri, + }, + Contract: contract.Contract{ + Service: serviceTypeForMethod(method), + }, + View: &view.View{Reference: shared.Reference{Ref: rootView}}, + } + if component != nil { + route.TypeContext = component.TypeContext + if component.Directives != nil && component.Directives.MCP != nil { + route.Name = strings.TrimSpace(component.Directives.MCP.Name) + route.Description = strings.TrimSpace(component.Directives.MCP.Description) + route.DescriptionURI = strings.TrimSpace(component.Directives.MCP.DescriptionPath) + } + } + if component != nil && (len(component.Input) > 0 || len(component.Meta) > 0) { + params := make(state.Parameters, 0, len(component.Input)+len(component.Meta)) + for _, s := range component.Input { + if s != nil { + p := s.Parameter + params = append(params, &p) + } + } + for _, s := range component.Meta { + if s != nil { + p := s.Parameter + params = append(params, &p) + } + } + if len(params) > 0 { + route.Contract.Input.Type.Parameters = params + } + } + payload := &shapeRuleFile{ + Resource: resource, + Routes: []*repository.Component{route}, + } + if component != nil && component.TypeContext != nil { + payload.TypeContext = component.TypeContext + } + data, err := yaml.Marshal(payload) + if err != nil { + return err + } + if err = s.fs.Upload(ctx, routeYAML, file.DefaultFileOsMode, strings.NewReader(string(data))); err != nil { + return fmt.Errorf("failed to persist route yaml %s: %w", routeYAML, err) + } + generateShapeTypes(url.Path(sourceURL), payload, component) + return nil +} + +func routePathForShape(rule *options.Rule, repoURL, sourceURL string) 
(routeYAML string, routeRoot string, relDir string, stem string, err error) { + sourcePath := filepath.Clean(url.Path(sourceURL)) + basePath := filepath.Clean(rule.BaseRuleURL()) + relative, relErr := filepath.Rel(basePath, sourcePath) + if relErr != nil || strings.HasPrefix(relative, "..") { + relative = filepath.Base(sourcePath) + } + relative = filepath.ToSlash(relative) + relDir = filepath.ToSlash(path.Dir(relative)) + if relDir == "." { + relDir = "" + } + stem = strings.TrimSuffix(path.Base(relative), path.Ext(relative)) + routeRoot = url.Join(repoURL, "Datly", "routes") + routeYAML = url.Join(routeRoot, relDir, stem+".yaml") + return routeYAML, routeRoot, relDir, stem, nil +} + +type shapeRuleHeader struct { + Method string `json:"Method"` + URI string `json:"URI"` +} + +func parseShapeRulePath(dql, ruleName, apiPrefix string) (string, string) { + method := "GET" + uri := "/" + strings.Trim(strings.TrimSpace(ruleName), "/") + if prefix := strings.TrimSpace(apiPrefix); prefix != "" { + uri = strings.TrimRight(prefix, "/") + uri + } + start := strings.Index(dql, "/*") + end := strings.Index(dql, "*/") + if start != -1 && end > start+2 { + raw := strings.TrimSpace(dql[start+2 : end]) + if strings.HasPrefix(raw, "{") && strings.HasSuffix(raw, "}") { + header := &shapeRuleHeader{} + if err := json.Unmarshal([]byte(raw), header); err == nil { + if candidate := strings.TrimSpace(strings.ToUpper(header.Method)); candidate != "" { + method = candidate + } + if candidate := strings.TrimSpace(header.URI); candidate != "" { + uri = candidate + } + } + } + } + return method, uri +} diff --git a/cmd/command/translate_shape_ir.go b/cmd/command/translate_shape_ir.go new file mode 100644 index 000000000..c2f37155f --- /dev/null +++ b/cmd/command/translate_shape_ir.go @@ -0,0 +1,168 @@ +package command + +import ( + "context" + "fmt" + "path" + "strings" + + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/cmd/options" + 
"github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + "github.com/viant/datly/repository/shape/dql/ir" + dqlyaml "github.com/viant/datly/repository/shape/dql/render/yaml" + shapeLoad "github.com/viant/datly/repository/shape/load" + datlyservice "github.com/viant/datly/service" + "github.com/viant/datly/shared" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "gopkg.in/yaml.v3" +) + +func (s *Service) translateShapeIR(ctx context.Context, opts *options.Options) error { + rule := opts.Rule() + compiler := shapeCompile.New() + loader := shapeLoad.New() + for rule.Index = 0; rule.Index < len(rule.Source); rule.Index++ { + // Reuse legacy signature bootstrap so shape IR flow gets the same registry/signature context when available. + if err := s.ensureTranslator(opts); err == nil && s.translator != nil { + _ = s.translator.InitSignature(ctx, rule) + } + sourceURL := rule.SourceURL() + _, name := url.Split(sourceURL, file.Scheme) + fmt.Printf("translating %v (shape-ir)\n", name) + dql, err := rule.LoadSource(ctx, s.fs, sourceURL) + if err != nil { + return err + } + dql = strings.TrimSpace(dql) + if dql == "" { + return fmt.Errorf("source %s was empty", sourceURL) + } + shapeSource := &shape.Source{ + Name: strings.TrimSuffix(name, path.Ext(name)), + Path: url.Path(sourceURL), + DQL: dql, + Connector: strings.TrimSpace(rule.Connector), + } + planResult, err := compiler.Compile(ctx, shapeSource) + if err != nil { + return fmt.Errorf("failed to compile %s: %w", sourceURL, err) + } + componentArtifact, err := loader.LoadComponent(ctx, planResult) + if err != nil { + return fmt.Errorf("failed to load %s: %w", sourceURL, err) + } + component, ok := shapeLoad.ComponentFrom(componentArtifact) + if !ok { + return fmt.Errorf("unexpected component artifact for %s", sourceURL) + } + + payload, err := 
buildShapeRulePayload(opts, dql, componentArtifact.Resource, component) + if err != nil { + return err + } + routeYAML, err := yaml.Marshal(payload) + if err != nil { + return err + } + document, err := ir.FromYAML(routeYAML) + if err != nil { + return fmt.Errorf("failed to build IR from %s: %w", sourceURL, err) + } + encoded, err := dqlyaml.Encode(document) + if err != nil { + return fmt.Errorf("failed to encode IR for %s: %w", sourceURL, err) + } + + routeYAMLPath, _, _, _, err := routePathForShape(rule, opts.Repository().RepositoryURL, sourceURL) + if err != nil { + return err + } + irPath := strings.TrimSuffix(routeYAMLPath, ".yaml") + ".ir.yaml" + if err = s.fs.Upload(ctx, irPath, file.DefaultFileOsMode, strings.NewReader(string(encoded))); err != nil { + return fmt.Errorf("failed to persist route ir %s: %w", irPath, err) + } + generateShapeTypes(url.Path(sourceURL), payload, component) + } + return nil +} + +func buildShapeRulePayload(opts *options.Options, dql string, resource *view.Resource, component *shapeLoad.Component) (*shapeRuleFile, error) { + rule := opts.Rule() + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + if rootView == "" && resource != nil && len(resource.Views) > 0 && resource.Views[0] != nil { + rootView = resource.Views[0].Name + } + method, uri := parseShapeRulePath(dql, rule.RuleName(), opts.Repository().APIPrefix) + // Gap 3: RouteDirective overrides method/URI when explicitly declared in DQL. 
+ if component != nil && component.Directives != nil && component.Directives.Route != nil { + rd := component.Directives.Route + if u := strings.TrimSpace(rd.URI); u != "" { + uri = u + } + if len(rd.Methods) > 0 { + if m := strings.TrimSpace(strings.ToUpper(rd.Methods[0])); m != "" { + method = m + } + } + } + route := &repository.Component{ + Path: contract.Path{ + Method: method, + URI: uri, + }, + Contract: contract.Contract{ + Service: serviceTypeForMethod(method), + }, + View: &view.View{Reference: shared.Reference{Ref: rootView}}, + } + if component != nil { + route.TypeContext = component.TypeContext + if component.Directives != nil && component.Directives.MCP != nil { + route.Name = strings.TrimSpace(component.Directives.MCP.Name) + route.Description = strings.TrimSpace(component.Directives.MCP.Description) + route.DescriptionURI = strings.TrimSpace(component.Directives.MCP.DescriptionPath) + } + } + if component != nil && (len(component.Input) > 0 || len(component.Meta) > 0) { + params := make(state.Parameters, 0, len(component.Input)+len(component.Meta)) + for _, s := range component.Input { + if s != nil { + p := s.Parameter + params = append(params, &p) + } + } + for _, s := range component.Meta { + if s != nil { + p := s.Parameter + params = append(params, &p) + } + } + if len(params) > 0 { + route.Contract.Input.Type.Parameters = params + } + } + payload := &shapeRuleFile{ + Resource: resource, + Routes: []*repository.Component{route}, + } + if component != nil && component.TypeContext != nil { + payload.TypeContext = component.TypeContext + } + return payload, nil +} + +func serviceTypeForMethod(method string) datlyservice.Type { + if strings.EqualFold(method, "GET") { + return datlyservice.TypeReader + } + return datlyservice.TypeExecutor +} diff --git a/cmd/command/translate_shape_test.go b/cmd/command/translate_shape_test.go new file mode 100644 index 000000000..b76fac4eb --- /dev/null +++ b/cmd/command/translate_shape_test.go @@ -0,0 +1,30 @@ 
+package command + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/cmd/options" +) + +func TestParseShapeRulePath(t *testing.T) { + method, uri := parseShapeRulePath(`/* {"Method":"POST","URI":"/v1/api/orders"} */ SELECT 1`, "orders", "/v1/api") + assert.Equal(t, "POST", method) + assert.Equal(t, "/v1/api/orders", uri) + + method, uri = parseShapeRulePath(`SELECT 1`, "orders", "/v1/api") + assert.Equal(t, "GET", method) + assert.Equal(t, "/v1/api/orders", uri) +} + +func TestRoutePathForShape(t *testing.T) { + rule := &options.Rule{Project: "/repo", Source: []string{"/repo/dql/platform/campaign/post.dql"}} + routeYAML, routeRoot, relDir, stem, err := routePathForShape(rule, "/repo/dev", "/repo/dql/platform/campaign/post.dql") + require.NoError(t, err) + assert.Equal(t, "/repo/dev/Datly/routes/platform/campaign/post.yaml", routeYAML) + assert.Equal(t, "/repo/dev/Datly/routes", routeRoot) + assert.Equal(t, filepath.ToSlash("platform/campaign"), relDir) + assert.Equal(t, "post", stem) +} diff --git a/cmd/command/translate_shape_xgen.go b/cmd/command/translate_shape_xgen.go new file mode 100644 index 000000000..3827ac75f --- /dev/null +++ b/cmd/command/translate_shape_xgen.go @@ -0,0 +1,100 @@ +package command + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + dqlir "github.com/viant/datly/repository/shape/dql/ir" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + shapeLoad "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/repository/shape/xgen" + "gopkg.in/yaml.v3" +) + +// generateShapeTypes emits a Go type file (shapes_gen.go) for the compiled +// component. It is a best-effort step: any failure is logged as a warning so +// the route YAML is still written successfully. 
+// +// Normal flow — types already exist in pkg/shapes_gen.go: +// +// xgen merges the file, updating only the types produced by this DQL. +// +// Backfill flow — no types file yet: +// +// xgen generates stub types from the statically-inferred columns +// (explicit SELECT columns give accurate field names/types; SELECT * +// produces a minimal stub that the user should refine or regenerate +// after DB discovery). +func generateShapeTypes(sourceAbsPath string, payload *shapeRuleFile, component *shapeLoad.Component) { + if component == nil || component.TypeContext == nil { + return + } + ctx := component.TypeContext + if strings.TrimSpace(ctx.PackageDir) == "" { + return + } + + projectDir := findProjectDir(sourceAbsPath) + if projectDir == "" { + fmt.Printf("WARNING: shape xgen: cannot locate go.mod from %s, skipping type generation\n", sourceAbsPath) + return + } + + packageDir := strings.TrimSpace(ctx.PackageDir) + if !filepath.IsAbs(packageDir) { + packageDir = filepath.Join(projectDir, packageDir) + } + + data, err := yaml.Marshal(payload) + if err != nil { + fmt.Printf("WARNING: shape xgen: marshal failed for %s: %v\n", sourceAbsPath, err) + return + } + doc, err := dqlir.FromYAML(data) + if err != nil { + fmt.Printf("WARNING: shape xgen: IR parse failed for %s: %v\n", sourceAbsPath, err) + return + } + + shapeDoc := buildShapeDocument(doc, ctx) + cfg := &xgen.Config{ + ProjectDir: projectDir, + PackageDir: packageDir, + PackageName: strings.TrimSpace(ctx.PackageName), + PackagePath: strings.TrimSpace(ctx.PackagePath), + } + + result, err := xgen.GenerateFromDQLShape(shapeDoc, cfg) + if err != nil { + fmt.Printf("WARNING: shape xgen: type generation skipped for %s: %v\n", filepath.Base(sourceAbsPath), err) + return + } + fmt.Printf("generated types %s → %s\n", strings.Join(result.Types, ", "), result.FilePath) +} + +// buildShapeDocument bridges an ir.Document into the shape.Document expected by xgen. 
+func buildShapeDocument(doc *dqlir.Document, ctx *typectx.Context) *dqlshape.Document { + return &dqlshape.Document{ + Root: doc.Root, + TypeContext: ctx, + } +} + +// findProjectDir walks up from sourcePath until it finds a directory containing +// go.mod, returning that directory. Returns "" when no go.mod is found. +func findProjectDir(sourcePath string) string { + dir := filepath.Dir(filepath.Clean(sourcePath)) + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + return "" + } + dir = parent + } +} diff --git a/cmd/datly/build.yaml b/cmd/datly/build.yaml index 9b2fefa03..92ae41d35 100644 --- a/cmd/datly/build.yaml +++ b/cmd/datly/build.yaml @@ -9,7 +9,7 @@ pipeline: set_sdk: action: sdk.set target: $target - sdk: go:1.23 + sdk: go:1.25.1 build: action: exec:run target: $target diff --git a/cmd/options/query_test.go b/cmd/options/query_test.go new file mode 100644 index 000000000..055d3915d --- /dev/null +++ b/cmd/options/query_test.go @@ -0,0 +1,34 @@ +package options + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/internal/testutil/sqlnormalizer" + "github.com/viant/sqlparser" +) + +func parserOption() sqlparser.Option { + return sqlparser.WithErrorHandler(nil) +} + +func TestRule_NormalizeSQL(t *testing.T) { + for _, testCase := range sqlnormalizer.Cases() { + t.Run(testCase.Name, func(t *testing.T) { + rule := &Rule{Generated: testCase.Generated} + actual := rule.NormalizeSQL(testCase.SQL, parserOption) + require.Equal(t, testCase.Expect, actual) + }) + } +} + +func TestMapper_Map(t *testing.T) { + m := mapper{"a": "A"} + require.Equal(t, "A", m.Map("a")) + require.Equal(t, "b", m.Map("b")) +} + +func TestNormalizeName(t *testing.T) { + require.Equal(t, "UserAlias", normalizeName("user_alias")) + require.Equal(t, "UserAlias", normalizeName("UserAlias")) +} diff --git a/cmd/options/rule.go b/cmd/options/rule.go index 
4528e972f..4d43015c0 100644 --- a/cmd/options/rule.go +++ b/cmd/options/rule.go @@ -22,6 +22,7 @@ type Rule struct { Name string `short:"n" long:"name" description:"rule name"` ModulePrefix string `short:"u" long:"namespace" description:"rule uri/namespace" default:"dev" ` Source []string `short:"s" long:"src" description:"source"` + Engine string `long:"engine" description:"translation engine" choice:"internal" choice:"legacy" choice:"shape" choice:"shape-ir"` Packages []string `short:"g" long:"pkg" description:"entity package"` Output []string Index int @@ -33,6 +34,29 @@ type Rule struct { IncludePredicates bool `short:"K" long:"inclPred" description:"generate predicate code" ` } +const ( + EngineInternal = "internal" + EngineLegacy = "legacy" // alias of internal, kept for migration compatibility + EngineShape = "shape" + EngineShapeIR = "shape-ir" +) + +func (r *Rule) EffectiveEngine() string { + engine := strings.ToLower(strings.TrimSpace(r.Engine)) + switch engine { + case "", EngineInternal, EngineLegacy: + return EngineInternal + case "shapeir": + return EngineShapeIR + case EngineShapeIR: + return EngineShapeIR + case EngineShape: + return EngineShape + default: + return EngineInternal + } +} + // Module returns go module func (r *Rule) Module() (*modfile.Module, error) { if r.module != nil { diff --git a/cmd/options/rule_engine_test.go b/cmd/options/rule_engine_test.go new file mode 100644 index 000000000..08f27bdca --- /dev/null +++ b/cmd/options/rule_engine_test.go @@ -0,0 +1,25 @@ +package options + +import "testing" + +func TestRule_EffectiveEngine(t *testing.T) { + testCases := []struct { + name string + engine string + want string + }{ + {name: "default", engine: "", want: EngineInternal}, + {name: "internal", engine: "internal", want: EngineInternal}, + {name: "legacy alias", engine: "legacy", want: EngineInternal}, + {name: "shape", engine: "shape", want: EngineShape}, + {name: "shape ir", engine: "shape-ir", want: EngineShapeIR}, + {name: "shape 
ir alias", engine: "shapeir", want: EngineShapeIR}, + {name: "invalid", engine: "other", want: EngineInternal}, + } + for _, testCase := range testCases { + rule := &Rule{Engine: testCase.engine} + if got := rule.EffectiveEngine(); got != testCase.want { + t.Fatalf("%s: got %s, want %s", testCase.name, got, testCase.want) + } + } +} diff --git a/doc/example_test.go b/doc/example_test.go index 8add98241..eaebcd3c8 100644 --- a/doc/example_test.go +++ b/doc/example_test.go @@ -39,8 +39,8 @@ type Validation struct { IsValid bool } -// Example_ComponentDebugging show how to programmatically execute executor rule -func Example_ComponentDebugging() { +// Example shows how to programmatically execute executor rule. +func Example() { //Uncomment various additional debugging and troubleshuting // expand.SetPanicOnError(false) // read.ShowSQL(true) diff --git a/doc/extension/EXAMPLES.md b/doc/extension/EXAMPLES.md index 1e6d6614a..6f476a0a7 100644 --- a/doc/extension/EXAMPLES.md +++ b/doc/extension/EXAMPLES.md @@ -2200,128 +2200,7 @@ go 1.21 require ( github.com/aerospike/aerospike-client-go v4.5.2+incompatible - github.com/aws/aws-lambda-go v1.31.0 - github.com/francoispqt/gojay v1.2.13 - github.com/go-playground/universal-translator v0.18.0 // indirect - github.com/go-playground/validator v9.31.0+incompatible - github.com/go-sql-driver/mysql v1.7.0 - github.com/goccy/go-json v0.9.11 - github.com/golang-jwt/jwt/v4 v4.4.1 - github.com/google/gops v0.3.23 - github.com/google/uuid v1.3.0 - github.com/jessevdk/go-flags v1.5.0 - github.com/leodido/go-urn v1.2.1 // indirect - github.com/lib/pq v1.10.6 - github.com/mattn/go-sqlite3 v1.14.16 - github.com/onsi/gomega v1.20.2 // indirect - github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.8.4 - github.com/viant/afs v1.24.2 - github.com/viant/afsc v1.9.0 - github.com/viant/assertly v0.9.1-0.20220620174148-bab013f93a60 - github.com/viant/bigquery v0.2.1 - github.com/viant/cloudless v1.8.1 - github.com/viant/dsc v0.16.2 // 
indirect - github.com/viant/dsunit v0.10.8 - github.com/viant/dyndb v0.1.4-0.20221214043424-27654ab6ed9c - github.com/viant/gmetric v0.2.7-0.20220508155136-c2e3c95db446 - github.com/viant/godiff v0.4.1 - github.com/viant/parsly v0.2.0 - github.com/viant/pgo v0.10.3 - github.com/viant/scy v0.6.0 - github.com/viant/sqlx v0.8.0 - github.com/viant/structql v0.2.2 - github.com/viant/toolbox v0.34.6-0.20221112031702-3e7cdde7f888 - github.com/viant/velty v0.2.0 - github.com/viant/xdatly/types/custom v0.0.0-20230309034540-231985618fc7 - github.com/viant/xreflect v0.0.0-20230303201326-f50afb0feb0d - github.com/viant/xunsafe v0.8.4 - golang.org/x/mod v0.9.0 - golang.org/x/oauth2 v0.7.0 - google.golang.org/api v0.114.0 - gopkg.in/go-playground/assert.v1 v1.2.1 // indirect - gopkg.in/yaml.v3 v3.0.1 -) - -require ( - github.com/viant/govalidator v0.2.1 - github.com/viant/sqlparser v0.3.1-0.20230320162628-96274e82953f - golang.org/x/crypto v0.7.0 // indirect -) - -require ( - github.com/aws/aws-sdk-go v1.44.12 - github.com/aws/aws-sdk-go-v2/config v1.18.3 - github.com/aws/aws-sdk-go-v2/service/s3 v1.33.1 - github.com/viant/structology v0.2.0 - github.com/viant/xdatly/extension v0.0.0-20230323215422-3e5c3147f0e6 - github.com/viant/xdatly/handler v0.0.0-20230619231115-e622dd6aff79 - github.com/viant/xdatly/types/core v0.0.0-20230615201419-f5e46b6b011f -) - -require ( - cloud.google.com/go v0.110.0 // indirect - cloud.google.com/go/compute v1.19.0 // indirect - cloud.google.com/go/compute/metadata v0.2.3 // indirect - cloud.google.com/go/iam v0.13.0 // indirect - cloud.google.com/go/secretmanager v1.10.0 // indirect - cloud.google.com/go/storage v1.29.0 // indirect - github.com/aws/aws-sdk-go-v2 v1.18.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.13.3 // indirect - github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.7 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds 
v1.12.19 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.3.26 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25 // indirect - github.com/aws/aws-sdk-go-v2/service/dynamodb v1.17.8 // indirect - github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.13.27 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.7.20 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.27 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sns v1.20.11 // indirect - github.com/aws/aws-sdk-go-v2/service/sqs v1.22.0 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.11.25 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.17.5 // indirect - github.com/aws/smithy-go v1.13.5 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.0-20210816181553-5444fa50b93d // indirect - github.com/go-errors/errors v1.4.2 // indirect - github.com/go-playground/locales v0.14.0 // indirect - github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect - github.com/googleapis/gax-go/v2 v2.8.0 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/kr/pretty v0.3.0 // indirect - github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect - github.com/lestrrat-go/blackmagic v1.0.0 // indirect - github.com/lestrrat-go/httpcc v1.0.1 
// indirect - github.com/lestrrat-go/iter v1.0.1 // indirect - github.com/lestrrat-go/jwx v1.2.25 // indirect - github.com/lestrrat-go/option v1.0.0 // indirect - github.com/michael/mymodule2 v0.0.0-00010101000000-000000000000 // indirect - github.com/nxadm/tail v1.4.8 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.9.0 // indirect - github.com/viant/igo v0.1.0 // indirect - github.com/yuin/gopher-lua v0.0.0-20221210110428-332342483e3f // indirect - go.opencensus.io v0.24.0 // indirect - golang.org/x/net v0.9.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.7.0 // indirect - golang.org/x/term v0.7.0 // indirect - golang.org/x/text v0.9.0 // indirect - golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect - google.golang.org/grpc v1.54.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + .... 
gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/e2e/local/build.yaml b/e2e/local/build.yaml index b9b3f87b3..bb6e6bbe3 100644 --- a/e2e/local/build.yaml +++ b/e2e/local/build.yaml @@ -14,7 +14,7 @@ pipeline: set_sdk: action: sdk.set target: $target - sdk: go:1.23 + sdk: go:1.25.5 buildValidator: action: exec:run diff --git a/e2e/local/regression/cases/001_one_to_many/expect_2.txt b/e2e/local/regression/cases/001_one_to_many/expect_2.txt index 2292a8c77..7796339fb 100644 --- a/e2e/local/regression/cases/001_one_to_many/expect_2.txt +++ b/e2e/local/regression/cases/001_one_to_many/expect_2.txt @@ -18,7 +18,7 @@ type GeneratedStruct struct { type Products struct { Id int `sqlx:"ID" velty:"names=ID|Id"` Name *string `sqlx:"NAME" velty:"names=NAME|Name"` - VendorId *int `sqlx:"VENDOR_ID" velty:"names=VENDOR_ID|VendorId"` + VendorId *int `sqlx:"VENDOR_ID" internal:"true" velty:"names=VENDOR_ID|VendorId"` Status *int `sqlx:"STATUS" velty:"names=STATUS|Status"` Created *time.Time `sqlx:"CREATED" velty:"names=CREATED|Created"` UserCreated *int `sqlx:"USER_CREATED" velty:"names=USER_CREATED|UserCreated"` diff --git a/e2e/local/regression/regression.yaml b/e2e/local/regression/regression.yaml index f4ce9a18c..a8b575fae 100644 --- a/e2e/local/regression/regression.yaml +++ b/e2e/local/regression/regression.yaml @@ -2,10 +2,10 @@ init: v1: abc v2: def pipeline: - set_sdk: - action: sdk.set - target: $target - sdk: go:1.23 +# set_sdk: +# action: sdk.set +# target: $target +# sdk: go:1.25.1 database: action: run @@ -30,7 +30,7 @@ pipeline: '[]gen': '@gen' subPath: 'cases/${index}_*' - #range: 1..007 + range: 11..020 template: checkSkip: action: nop @@ -39,5 +39,3 @@ pipeline: test: action: run request: '@test' - - diff --git a/e2e/local/regression/rule.yaml b/e2e/local/regression/rule.yaml index f43269480..5d6f0a3ce 100644 --- a/e2e/local/regression/rule.yaml +++ b/e2e/local/regression/rule.yaml @@ -10,6 +10,7 @@ pipeline: commands: - mkdir -p ${appPath}/e2e/local/autogen - rm 
-rf ${appPath}/e2e/local/autogen + - rm -f ${appPath}/e2e/local/regression/paths.yaml loop: diff --git a/e2e/mcp/debug.go b/e2e/mcp/debug.go index 701bd8239..0a0ce60d6 100644 --- a/e2e/mcp/debug.go +++ b/e2e/mcp/debug.go @@ -3,6 +3,9 @@ package main import ( "context" "fmt" + "github.com/viant/jsonrpc/transport/client/stdio" + "github.com/viant/mcp-protocol/schema" + "github.com/viant/mcp/client" "github.com/viant/toolbox" "log" "path/filepath" @@ -25,10 +28,11 @@ func main() { fmt.Println(args) fmt.Println("Starting MCP client with args:", datlyBin+strings.Join(args, " ")) - c, err := client.NewStdioMCPClient(datlyBin, []string{}, args...) + transport, err := stdio.New(datlyBin, stdio.WithArguments(strings.Join(args, " "))) if err != nil { - log.Fatalf("Failed to create client: %v", err) + log.Fatalf("Failed to create stdio transport: %v", err) } + c := client.New("datly-debug", "0.1", transport) defer c.Close() // Create context with timeout @@ -37,14 +41,7 @@ func main() { // Initialize the client fmt.Println("Initializing client...") - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "example-client", - Version: "1.0.0", - } - - initResult, err := c.Initialize(ctx, initRequest) + initResult, err := c.Initialize(ctx) if err != nil { log.Fatalf("Failed to initialize: %v", err) } @@ -54,26 +51,20 @@ func main() { initResult.ServerInfo.Version, ) - readRequest := mcp.ReadResourceRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodResourcesRead), - }, - } - readRequest.Params.URI = "datly://localhost/v1/api/dev/vendors/{vendorID}" - readRequest.Params.Arguments = map[string]interface{}{ - "vendorID": "12345", // Example vendor ID to read - } - - c.ReadResource(ctx, readRequest) // ensure the client is initialized before proceeding + readRequest := &schema.ReadResourceRequestParams{Uri: "datly://localhost/v1/api/dev/vendors/12345"} + _, _ = 
c.ReadResource(ctx, readRequest) // ensure the client is initialized before proceeding // List Tools fmt.Println("Listing available tools...") - toolsRequest := mcp.ListResourceTemplatesRequest{} - tools, err := c.ListResourceTemplates(ctx, toolsRequest) + tools, err := c.ListResourceTemplates(ctx, nil) if err != nil { log.Fatalf("Failed to list tools: %v", err) } for _, tool := range tools.ResourceTemplates { - fmt.Printf("- %s: %s\n", tool.Name, tool.Description) + desc := "" + if tool.Description != nil { + desc = *tool.Description + } + fmt.Printf("- %s: %s\n", tool.Name, desc) } } diff --git a/gateway/async.go b/gateway/async.go index 2ef1b2c83..d7228a14c 100644 --- a/gateway/async.go +++ b/gateway/async.go @@ -76,7 +76,7 @@ func (r *Service) watchAsyncJob(ctx context.Context) { err = fs.Move(ctx, object.URL(), destURL) } if err != nil { - log.Println(err) + log.Printf("datly async post-process failed: source=%q err=%v", object.URL(), err) } } else { diff --git a/gateway/config.go b/gateway/config.go index 5aff4b253..9aedccb0e 100644 --- a/gateway/config.go +++ b/gateway/config.go @@ -29,6 +29,7 @@ type ( ExposableConfig struct { APIPrefix string //like /v1/api/ RouteURL string + DQLBootstrap *DQLBootstrap ContentURL string PluginsURL string DependencyURL string @@ -63,6 +64,25 @@ type ( RetryIntervalInS int _retry time.Duration } + + DQLBootstrap struct { + Sources []string + Exclude []string + FailFast *bool + Precedence string + CompileProfile string + MixedMode string + UnknownNonReadMode string + ColumnDiscoveryMode string + DQLPathMarker string + RoutesRelativePath string + } +) + +const ( + DQLBootstrapPrecedenceRoutesWins = "routes_wins" + DQLBootstrapPrecedenceDQLWins = "dql_wins" + DQLBootstrapPrecedenceErrorOnMixed = "error_on_conflict" ) func (d *ChangeDetection) Init() { @@ -78,12 +98,38 @@ func (d *ChangeDetection) Init() { } func (c *Config) Validate() error { - if c.RouteURL == "" { + if c.DQLBootstrap != nil && len(c.DQLBootstrap.Sources) == 0 
{ + return fmt.Errorf("DQLBootstrap.Sources was empty") + } + if c.RouteURL == "" && !c.hasDQLBootstrap() { return fmt.Errorf("RouteURL was empty") } return nil } +func (c *Config) hasDQLBootstrap() bool { + return c != nil && c.DQLBootstrap != nil && len(c.DQLBootstrap.Sources) > 0 +} + +func (d *DQLBootstrap) ShouldFailFast() bool { + if d == nil || d.FailFast == nil { + return true + } + return *d.FailFast +} + +func (d *DQLBootstrap) EffectivePrecedence() string { + if d == nil { + return DQLBootstrapPrecedenceRoutesWins + } + switch strings.TrimSpace(strings.ToLower(d.Precedence)) { + case DQLBootstrapPrecedenceRoutesWins, DQLBootstrapPrecedenceDQLWins, DQLBootstrapPrecedenceErrorOnMixed: + return strings.TrimSpace(strings.ToLower(d.Precedence)) + default: + return DQLBootstrapPrecedenceRoutesWins + } +} + func (c *Config) Discovery() bool { return c.AutoDiscovery == nil || *c.AutoDiscovery } diff --git a/gateway/dql_bootstrap.go b/gateway/dql_bootstrap.go new file mode 100644 index 000000000..fe7d62dbc --- /dev/null +++ b/gateway/dql_bootstrap.go @@ -0,0 +1,453 @@ +package gateway + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path" + "path/filepath" + "sort" + "strings" + + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + datlyservice "github.com/viant/datly/service" + "github.com/viant/datly/view" +) + +func (r *Service) applyDQLBootstrap(ctx context.Context, repo *repository.Service, cfg *DQLBootstrap) error { + if cfg == nil || len(cfg.Sources) == 0 { + return nil + } + sources, err := discoverDQLBootstrapSources(cfg.Sources, cfg.Exclude) + if err != nil { + return err + } + if len(sources) == 0 { + return fmt.Errorf("no DQL bootstrap sources matched") + } + compiler := shapeCompile.New() + loader := shapeLoad.New() + precedence := 
cfg.EffectivePrecedence() + var errors []error + for _, sourcePath := range sources { + component, err := compileBootstrapComponent(ctx, compiler, loader, repo, sourcePath, cfg, r.Config.APIPrefix) + if err != nil { + if cfg.ShouldFailFast() { + return err + } + errors = append(errors, err) + continue + } + exists, lookupErr := hasRepositoryProvider(ctx, repo, &component.Path) + if lookupErr != nil { + if cfg.ShouldFailFast() { + return lookupErr + } + errors = append(errors, lookupErr) + continue + } + if exists { + switch precedence { + case DQLBootstrapPrecedenceRoutesWins: + continue + case DQLBootstrapPrecedenceErrorOnMixed: + err = fmt.Errorf("DQL bootstrap conflict for %s:%s", component.Method, component.URI) + if cfg.ShouldFailFast() { + return err + } + errors = append(errors, err) + continue + } + } + repo.Register(component) + } + if len(errors) > 0 { + return fmt.Errorf("DQL bootstrap completed with %d errors: %w", len(errors), errors[0]) + } + return nil +} + +func compileBootstrapComponent(ctx context.Context, compiler *shapeCompile.DQLCompiler, loader *shapeLoad.Loader, repo *repository.Service, sourcePath string, cfg *DQLBootstrap, apiPrefix string) (*repository.Component, error) { + data, err := os.ReadFile(sourcePath) + if err != nil { + return nil, fmt.Errorf("failed to read DQL bootstrap source %s: %w", sourcePath, err) + } + dql := strings.TrimSpace(string(data)) + if dql == "" { + return nil, fmt.Errorf("empty DQL bootstrap source: %s", sourcePath) + } + sourceName := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + source := &shape.Source{ + Name: sourceName, + Path: sourcePath, + DQL: dql, + } + planResult, err := compiler.Compile(ctx, source, compileOptionsFromBootstrap(cfg)...) 
+ if err != nil { + return nil, fmt.Errorf("failed to compile DQL bootstrap source %s: %w", sourcePath, err) + } + componentArtifact, err := loader.LoadComponent(ctx, planResult) + if err != nil { + return nil, fmt.Errorf("failed to load DQL bootstrap source %s: %w", sourcePath, err) + } + normalizeBootstrapInlineSQL(componentArtifact.Resource) + mergeBootstrapSharedResources(componentArtifact.Resource, repo) + loaded, ok := componentArtifact.Component.(*shapeLoad.Component) + if !ok || loaded == nil { + return nil, fmt.Errorf("unexpected shape component artifact for %s", sourcePath) + } + rootView := lookupRootView(componentArtifact.Resource, loaded.RootView) + if rootView == nil { + return nil, fmt.Errorf("missing root view %q for %s", loaded.RootView, sourcePath) + } + method, uri := resolvePathSettings(sourcePath, dql, apiPrefix) + componentModel := &repository.Component{ + Path: contract.Path{ + Method: method, + URI: uri, + }, + Contract: contract.Contract{ + Service: defaultServiceForMethod(method, rootView), + }, + View: rootView, + TypeContext: loaded.TypeContext, + } + loadOptions := []repository.Option{} + if repo != nil { + loadOptions = append(loadOptions, repository.WithResources(repo.Resources())) + loadOptions = append(loadOptions, repository.WithExtensions(repo.Extensions())) + } + components, err := repository.LoadComponentsFromMap(ctx, map[string]any{ + "Resource": componentArtifact.Resource, + "Components": []*repository.Component{componentModel}, + }, loadOptions...) 
+ if err != nil { + return nil, fmt.Errorf("failed to materialize bootstrap component for %s: %w", sourcePath, err) + } + if err = components.Init(ctx); err != nil { + return nil, fmt.Errorf("failed to initialize bootstrap component for %s: %w", sourcePath, err) + } + if len(components.Components) == 0 || components.Components[0] == nil { + return nil, fmt.Errorf("empty initialized bootstrap component for %s", sourcePath) + } + return components.Components[0], nil +} + +func mergeBootstrapSharedResources(target *view.Resource, repo *repository.Service) { + if target == nil || repo == nil || repo.Resources() == nil { + return + } + if connectors, err := repo.Resources().Lookup(view.ResourceConnectors); err == nil && connectors != nil && connectors.Resource != nil { + target.MergeFrom(connectors.Resource, nil) + } + if constants, err := repo.Resources().Lookup(view.ResourceConstants); err == nil && constants != nil && constants.Resource != nil { + target.MergeFrom(constants.Resource, nil) + } +} + +func normalizeBootstrapInlineSQL(resource *view.Resource) { + if resource == nil { + return + } + for _, item := range resource.Views { + if item == nil || item.Template == nil { + continue + } + // DQL bootstrap compiles from in-memory source; keep SQL inline and avoid filesystem lookups. 
+ item.Template.SourceURL = "" + } +} + +func defaultServiceForMethod(method string, rootView *view.View) datlyservice.Type { + if strings.EqualFold(method, "GET") { + return datlyservice.TypeReader + } + if rootView != nil && rootView.Mode == view.ModeQuery { + return datlyservice.TypeReader + } + return datlyservice.TypeExecutor +} + +func hasRepositoryProvider(ctx context.Context, repo *repository.Service, path *contract.Path) (bool, error) { + if repo == nil || repo.Registry() == nil || path == nil { + return false, nil + } + _, err := repo.Registry().LookupProvider(ctx, path) + if err != nil { + message := strings.ToLower(strings.TrimSpace(err.Error())) + if strings.Contains(message, "not found") { + return false, nil + } + return false, err + } + return true, nil +} + +func compileOptionsFromBootstrap(cfg *DQLBootstrap) []shape.CompileOption { + if cfg == nil { + return nil + } + var result []shape.CompileOption + switch strings.ToLower(strings.TrimSpace(cfg.CompileProfile)) { + case string(shape.CompileProfileStrict): + result = append(result, shape.WithCompileProfile(shape.CompileProfileStrict)) + case string(shape.CompileProfileCompat): + result = append(result, shape.WithCompileProfile(shape.CompileProfileCompat)) + } + switch strings.ToLower(strings.TrimSpace(cfg.MixedMode)) { + case string(shape.CompileMixedModeExecWins): + result = append(result, shape.WithMixedMode(shape.CompileMixedModeExecWins)) + case string(shape.CompileMixedModeReadWins): + result = append(result, shape.WithMixedMode(shape.CompileMixedModeReadWins)) + case string(shape.CompileMixedModeErrorOnMixed): + result = append(result, shape.WithMixedMode(shape.CompileMixedModeErrorOnMixed)) + } + switch strings.ToLower(strings.TrimSpace(cfg.UnknownNonReadMode)) { + case string(shape.CompileUnknownNonReadWarn): + result = append(result, shape.WithUnknownNonReadMode(shape.CompileUnknownNonReadWarn)) + case string(shape.CompileUnknownNonReadError): + result = append(result, 
shape.WithUnknownNonReadMode(shape.CompileUnknownNonReadError)) + } + switch strings.ToLower(strings.TrimSpace(cfg.ColumnDiscoveryMode)) { + case string(shape.CompileColumnDiscoveryAuto): + result = append(result, shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryAuto)) + case string(shape.CompileColumnDiscoveryOn): + result = append(result, shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryOn)) + case string(shape.CompileColumnDiscoveryOff): + result = append(result, shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryOff)) + } + if marker := strings.TrimSpace(cfg.DQLPathMarker); marker != "" { + result = append(result, shape.WithDQLPathMarker(marker)) + } + if rel := strings.TrimSpace(cfg.RoutesRelativePath); rel != "" { + result = append(result, shape.WithRoutesRelativePath(rel)) + } + return result +} + +func discoverDQLBootstrapSources(includes, excludes []string) ([]string, error) { + seen := map[string]struct{}{} + var result []string + for _, include := range includes { + include = strings.TrimSpace(include) + if include == "" { + continue + } + expanded, err := expandBootstrapPattern(include) + if err != nil { + return nil, err + } + for _, candidate := range expanded { + if !isDQLSourceFile(candidate) { + continue + } + if matchesAnyPattern(candidate, excludes) { + continue + } + if _, ok := seen[candidate]; ok { + continue + } + seen[candidate] = struct{}{} + result = append(result, candidate) + } + } + sort.Strings(result) + return result, nil +} + +func expandBootstrapPattern(pattern string) ([]string, error) { + pattern = filepath.Clean(pattern) + if strings.Contains(pattern, "**") { + return expandDoubleStarPattern(pattern) + } + if hasGlobMeta(pattern) { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + return flattenPaths(matches) + } + return flattenPaths([]string{pattern}) +} + +func flattenPaths(items []string) ([]string, error) { + var result []string + for _, item := range items { + item = 
strings.TrimSpace(item) + if item == "" { + continue + } + info, err := os.Stat(item) + if err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + if !info.IsDir() { + result = append(result, item) + continue + } + err = filepath.WalkDir(item, func(candidate string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + if isDQLSourceFile(candidate) { + result = append(result, candidate) + } + return nil + }) + if err != nil { + return nil, err + } + } + return result, nil +} + +func expandDoubleStarPattern(pattern string) ([]string, error) { + slash := filepath.ToSlash(pattern) + index := strings.Index(slash, "**") + root := strings.TrimSuffix(slash[:index], "/") + if root == "" { + root = "." + } + rootPath := filepath.FromSlash(root) + var result []string + err := filepath.WalkDir(rootPath, func(candidate string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + normalized := filepath.ToSlash(candidate) + if !globMatch(slash, normalized) { + return nil + } + result = append(result, candidate) + return nil + }) + return result, err +} + +func hasGlobMeta(pattern string) bool { + return strings.ContainsAny(pattern, "*?[") +} + +func matchesAnyPattern(candidate string, patterns []string) bool { + for _, pattern := range patterns { + pattern = strings.TrimSpace(pattern) + if pattern == "" { + continue + } + if globMatch(filepath.ToSlash(pattern), filepath.ToSlash(candidate)) { + return true + } + } + return false +} + +func globMatch(pattern, candidate string) bool { + pattern = filepath.ToSlash(pattern) + candidate = filepath.ToSlash(candidate) + if strings.Contains(pattern, "**") { + return matchDoubleStar(strings.Split(pattern, "/"), strings.Split(candidate, "/")) + } + ok, _ := path.Match(pattern, candidate) + return ok +} + +func matchDoubleStar(pattern, candidate []string) bool { + if len(pattern) == 0 { + 
return len(candidate) == 0 + } + head := pattern[0] + if head == "**" { + if matchDoubleStar(pattern[1:], candidate) { + return true + } + if len(candidate) > 0 { + return matchDoubleStar(pattern, candidate[1:]) + } + return false + } + if len(candidate) == 0 { + return false + } + ok, _ := path.Match(head, candidate[0]) + if !ok { + return false + } + return matchDoubleStar(pattern[1:], candidate[1:]) +} + +func isDQLSourceFile(path string) bool { + ext := strings.ToLower(strings.TrimSpace(filepath.Ext(path))) + return ext == ".dql" || ext == ".sql" +} + +func lookupRootView(resource *view.Resource, root string) *view.View { + if resource == nil { + return nil + } + name := strings.TrimSpace(root) + if name != "" { + if candidate, _ := resource.View(name); candidate != nil { + return candidate + } + } + if len(resource.Views) > 0 { + return resource.Views[0] + } + return nil +} + +type bootstrapRuleSettings struct { + Method string `json:"Method"` + URI string `json:"URI"` +} + +func resolvePathSettings(sourcePath, dql, apiPrefix string) (string, string) { + method := "GET" + uri := "" + settings := parseBootstrapRuleSettings(dql) + if settings != nil { + if candidate := strings.TrimSpace(strings.ToUpper(settings.Method)); candidate != "" { + method = candidate + } + uri = strings.TrimSpace(settings.URI) + } + if uri == "" { + stem := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + uri = "/" + strings.Trim(stem, "/") + if prefix := strings.TrimSpace(apiPrefix); prefix != "" { + uri = strings.TrimRight(prefix, "/") + uri + } + } + return method, uri +} + +func parseBootstrapRuleSettings(dql string) *bootstrapRuleSettings { + start := strings.Index(dql, "/*") + end := strings.Index(dql, "*/") + if start == -1 || end == -1 || end <= start+2 { + return nil + } + raw := strings.TrimSpace(dql[start+2 : end]) + if !strings.HasPrefix(raw, "{") || !strings.HasSuffix(raw, "}") { + return nil + } + ret := &bootstrapRuleSettings{} + if err := 
json.Unmarshal([]byte(raw), ret); err != nil { + return nil + } + return ret +} diff --git a/gateway/dql_bootstrap_test.go b/gateway/dql_bootstrap_test.go new file mode 100644 index 000000000..b36714cd0 --- /dev/null +++ b/gateway/dql_bootstrap_test.go @@ -0,0 +1,122 @@ +package gateway + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/view" +) + +func TestConfigValidate_AllowsEmptyRouteURLWithDQLBootstrap(t *testing.T) { + cfg := &Config{ + ExposableConfig: ExposableConfig{ + DQLBootstrap: &DQLBootstrap{ + Sources: []string{"./testdata/*.dql"}, + }, + }, + } + require.NoError(t, cfg.Validate()) +} + +func TestConfigValidate_FailsWithoutRouteAndBootstrap(t *testing.T) { + cfg := &Config{} + require.ErrorContains(t, cfg.Validate(), "RouteURL was empty") +} + +func TestConfigValidate_FailsForEmptyBootstrapSources(t *testing.T) { + cfg := &Config{ + ExposableConfig: ExposableConfig{ + DQLBootstrap: &DQLBootstrap{}, + }, + } + require.ErrorContains(t, cfg.Validate(), "DQLBootstrap.Sources was empty") +} + +func TestDiscoverDQLBootstrapSources(t *testing.T) { + root := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(root, "sql", "nested"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(root, "sql", "a.dql"), []byte("SELECT 1"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(root, "sql", "nested", "b.sql"), []byte("SELECT 2"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(root, "sql", "nested", "skip.dql"), []byte("SELECT 3"), 0o644)) + + sources, err := discoverDQLBootstrapSources( + []string{filepath.Join(root, "sql", "**", "*")}, + []string{filepath.Join(root, "sql", "**", "skip.dql")}, + ) + require.NoError(t, err) + require.Len(t, sources, 2) + assert.Contains(t, sources, filepath.Join(root, "sql", "a.dql")) + assert.Contains(t, 
sources, filepath.Join(root, "sql", "nested", "b.sql")) +} + +func TestResolvePathSettings(t *testing.T) { + method, uri := resolvePathSettings("/tmp/orders/get.dql", `/* {"Method":"POST","URI":"/v1/api/orders"} */ SELECT 1`, "/v1/api") + assert.Equal(t, "POST", method) + assert.Equal(t, "/v1/api/orders", uri) + + method, uri = resolvePathSettings("/tmp/orders/get.dql", `SELECT 1`, "/v1/api") + assert.Equal(t, "GET", method) + assert.Equal(t, "/v1/api/get", uri) +} + +func TestDQLBootstrapEffectivePrecedence(t *testing.T) { + assert.Equal(t, DQLBootstrapPrecedenceRoutesWins, (&DQLBootstrap{}).EffectivePrecedence()) + assert.Equal(t, DQLBootstrapPrecedenceDQLWins, (&DQLBootstrap{Precedence: "dql_wins"}).EffectivePrecedence()) + assert.Equal(t, DQLBootstrapPrecedenceRoutesWins, (&DQLBootstrap{Precedence: "unknown"}).EffectivePrecedence()) +} + +func TestApplyDQLBootstrap_Precedence(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + + route := contract.Path{Method: "GET", URI: "/v1/api/test"} + repo.Register(&repository.Component{Path: route}) + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "test_conn", + Driver: "sqlite3", + DSN: "sqlite:./test.db", + }, + }, + }) + + root := t.TempDir() + source := filepath.Join(root, "test.dql") + require.NoError(t, os.WriteFile(source, []byte(`/* {"Method":"GET","URI":"/v1/api/test","Connector":"test_conn"} */ SELECT 1 AS id`), 0o644)) + srv := &Service{Config: &Config{ExposableConfig: ExposableConfig{APIPrefix: "/v1/api"}}} + + routesWins := &DQLBootstrap{ + Sources: []string{source}, + Precedence: DQLBootstrapPrecedenceRoutesWins, + } + require.NoError(t, srv.applyDQLBootstrap(ctx, repo, routesWins)) + provider, err := 
repo.Registry().LookupProvider(ctx, &route) + require.NoError(t, err) + require.NotNil(t, provider) + component, err := provider.Component(ctx) + require.NoError(t, err) + assert.Nil(t, component.View) + + dqlWins := &DQLBootstrap{ + Sources: []string{source}, + Precedence: DQLBootstrapPrecedenceDQLWins, + } + require.NoError(t, srv.applyDQLBootstrap(ctx, repo, dqlWins)) + provider, err = repo.Registry().LookupProvider(ctx, &route) + require.NoError(t, err) + require.NotNil(t, provider) + component, err = provider.Component(ctx) + require.NoError(t, err) + require.NotNil(t, component.View) + assert.Equal(t, "test", component.View.Name) +} diff --git a/gateway/mcp.go b/gateway/mcp.go index 45310d541..4bf053020 100644 --- a/gateway/mcp.go +++ b/gateway/mcp.go @@ -4,6 +4,12 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "net/url" + "reflect" + "strings" + furl "github.com/viant/afs/url" "github.com/viant/datly/gateway/router/proxy" "github.com/viant/datly/repository" @@ -11,14 +17,10 @@ import ( "github.com/viant/datly/view/state" "github.com/viant/jsonrpc" "github.com/viant/mcp-protocol/authorization" + oauthmeta "github.com/viant/mcp-protocol/oauth2/meta" "github.com/viant/mcp-protocol/schema" serverproto "github.com/viant/mcp-protocol/server" "github.com/viant/toolbox" - "io" - "net/http" - "net/url" - "reflect" - "strings" ) func (r *Router) buildToolsIntegration(item *dpath.Item, aPath *dpath.Path, aRoute *Route, provider *repository.Provider) error { @@ -53,109 +55,259 @@ func (r *Router) buildToolsIntegration(item *dpath.Item, aPath *dpath.Path, aRou } func (r *Router) mcpToolCallHandler(component *repository.Component, aRoute *Route) serverproto.ToolHandlerFunc { - handler := func(ctx context.Context, req *schema.CallToolRequest) (*schema.CallToolResult, *jsonrpc.Error) { + return func(ctx context.Context, req *schema.CallToolRequest) (*schema.CallToolResult, *jsonrpc.Error) { params := req.Params - URI := r.matchToolCallComponentURI(aRoute, 
component, params) - URL := fmt.Sprintf("http://localhost/%v", strings.TrimLeft(URI, "/")) // fallback to a local URL for now, this should be replaced with the actual service URL + uri := r.matchToolCallComponentURI(aRoute, component, params) + baseURL := fmt.Sprintf("http://localhost/%v", strings.TrimLeft(uri, "/")) // replace with actual service URL when available + values := url.Values{} var body io.Reader - var uniquePath = make(map[string]bool) - var uniqueQuery = make(map[string]bool) - for _, parameter := range component.Input.Type.Parameters { - paramName := strings.Title(parameter.Name) - value := params.Arguments[paramName] - paramType := parameter.Schema.Type() - if paramType.Kind() == reflect.Ptr { - paramType = paramType.Elem() + uniquePath := map[string]bool{} + uniqueQuery := map[string]bool{} + + // 1) Collect parameters (component + selector pagination) + allParams := r.collectToolParameters(component) + + // 2) Apply parameters to request URL/query/body + for _, p := range allParams { + name := strings.Title(p.Name) + value := params.Arguments[name] + pType := p.Schema.Type() + if pType.Kind() == reflect.Ptr { + pType = pType.Elem() } + value = r.coerceNumericValue(value, pType) + var rpcErr *jsonrpc.Error + baseURL, body, rpcErr = r.applyParamToRequest(baseURL, values, p, value, uniquePath, uniqueQuery, body) + if rpcErr != nil { + return nil, rpcErr + } + } - switch paramType.Kind() { - case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: - if value == nil { - continue - } - value = toolbox.AsInt(value) + // 3) Finalize URL with query string + finalURL := baseURL + if enc := values.Encode(); enc != "" { + if strings.Contains(finalURL, "?") { + finalURL += "&" + enc + } else { + finalURL += "?" 
+ enc } + } - switch parameter.In.Kind { - case state.KindPath: - if uniquePath[parameter.In.Name] { - continue - } - uniquePath[parameter.In.Name] = true + // 4) Build HTTP request and route + httpReq, rpcErr := r.newToolHTTPRequest(aRoute.Path.Method, finalURL, body) + if rpcErr != nil { + return nil, rpcErr + } + r.addAuthTokenIfPresent(ctx, httpReq) - if value == nil { - return nil, jsonrpc.NewInvalidRequest("missing path parameter: "+parameter.In.Name, nil) - } + // NEW: map MCP view sync flag argument to Sync-Read header + r.addSyncReadHeaderIfPresent(ctx, component, ¶ms, httpReq) - URL = strings.ReplaceAll(URL, "{"+parameter.In.Name+"}", fmt.Sprintf("%v", value)) - case state.KindQuery, state.KindForm: - if uniqueQuery[parameter.In.Name] { - continue - } - uniqueQuery[parameter.In.Name] = true - if value == nil || value == "" { - continue - } - // Check if value is a slice and create a comma-separated string - if slice, ok := value.([]interface{}); ok { - var items []string - for _, item := range slice { - if f, ok := item.(float64); ok { - items = append(items, fmt.Sprintf("%v", int64(f))) - } else { - items = append(items, fmt.Sprintf("%v", item)) - } - } - values.Add(parameter.In.Name, strings.Join(items, ",")) - } else { - values.Add(parameter.In.Name, fmt.Sprintf("%v", value)) - } - case state.KindRequestBody: - if text, ok := value.(string); ok { - body = strings.NewReader(text) - } else { - data, err := json.Marshal(value) - if err != nil { - return nil, jsonrpc.NewInvalidParamsError("failed to marshal request body: %w", data) - } - body = strings.NewReader(string(data)) - } + httpReq.RequestURI = httpReq.URL.RequestURI() + if uri != aRoute.URI() { + if matched, _ := r.match(component.Method, uri, httpReq); matched != nil { + aRoute = matched } } - responseWriter := proxy.NewWriter() + rw := proxy.NewWriter() + aRoute.Handle(rw, httpReq) - // Add query parameters to URL if any exist - if len(values) > 0 { - if strings.Contains(URL, "?") { - URL += "&" 
+ values.Encode() - } else { - URL += "?" + values.Encode() + if rw.Code == http.StatusUnauthorized { + return nil, r.mcpUnauthorizedError() + } + + // 5) Build tool result (text + structured on error) + return r.buildToolCallResult(rw, finalURL, aRoute.Path.Method), nil + } +} + +func (r *Router) addSyncReadHeaderIfPresent( + ctx context.Context, + component *repository.Component, + params *schema.CallToolRequestParams, + httpRequest *http.Request, +) { + if params == nil || params.Arguments == nil { + return + } + // MCP tool arguments are generated using exported Go field names, so + // the Datly view sync flag (view.SyncFlag == "viewSyncFlag") will appear + // as "viewSyncFlag" in the schema/tool call. + const mcpSyncFlagArg = "viewSyncFlag" + const headerName = "Sync-Read" + + value, ok := params.Arguments[mcpSyncFlagArg] + if !ok { + return + } + + if !isTruthy(value) { + return + } + + // Optionally, ensure that the underlying component actually declares + // a sync flag parameter; if it does not, we simply skip setting the header. + if !hasSyncFlagParameter(component) { + return + } + + httpRequest.Header.Set(headerName, "true") +} + +// hasSyncFlagParameter checks whether the component declares a selector +// sync flag parameter, which should be exposed as view.SyncFlag. +func hasSyncFlagParameter(component *repository.Component) bool { + if component == nil || component.View == nil || component.View.Selector == nil { + return false + } + param := component.View.Selector.GetSyncFlagParameter() + if param == nil { + return false + } + // The selector sync flag parameter is defined in view.Config using + // view.SyncFlag as the state key, but here we simply check that it exists. + return true +} + +// isTruthy interprets common JSON-serialised truthy values. 
+func isTruthy(v interface{}) bool { + switch value := v.(type) { + case bool: + return value + case string: + s := strings.TrimSpace(strings.ToLower(value)) + return s == "true" || s == "1" || s == "yes" || s == "y" + case float64: + return value != 0 + default: + return false + } +} + +// collectToolParameters aggregates component input parameters with selector pagination (limit/offset) when available. +func (r *Router) collectToolParameters(component *repository.Component) []*state.Parameter { + var all []*state.Parameter + all = append(all, component.Input.Type.Parameters...) + if component.View != nil && component.View.Selector != nil { + if p := component.View.Selector.LimitParameter; p != nil { + all = append(all, p) + } + if p := component.View.Selector.OffsetParameter; p != nil { + all = append(all, p) + } + if p := component.View.Selector.FieldsParameter; p != nil { + all = append(all, p) + } + if p := component.View.Selector.PageParameter; p != nil { + all = append(all, p) + } + } + return all +} + +// coerceNumericValue normalizes numeric values to integers when appropriate. +func (r *Router) coerceNumericValue(value interface{}, paramType reflect.Type) interface{} { + switch paramType.Kind() { + case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: + if value == nil { + return nil + } + return toolbox.AsInt(value) + } + return value +} + +// applyParamToRequest applies a single parameter into path placeholders, query/form values, or request body. +func (r *Router) applyParamToRequest(baseURL string, values url.Values, p *state.Parameter, value interface{}, uniquePath, uniqueQuery map[string]bool, body io.Reader) (string, io.Reader, *jsonrpc.Error) { + switch p.In.Kind { + case state.KindPath: + if uniquePath[p.In.Name] { + return baseURL, body, nil + } + uniquePath[p.In.Name] = true + if value == nil { + // If parameter has its own URI segment configured, treat as optional and strip the placeholder. 
+ if p.URI != "" { + baseURL = strings.ReplaceAll(baseURL, "/{"+p.In.Name+"}", "") + baseURL = strings.ReplaceAll(baseURL, "{"+p.In.Name+"}", "") + return baseURL, body, nil } + return baseURL, body, jsonrpc.NewInvalidRequest("missing path parameter: "+p.In.Name, nil) + } + baseURL = strings.ReplaceAll(baseURL, "{"+p.In.Name+"}", fmt.Sprintf("%v", value)) + case state.KindQuery, state.KindForm: + if uniqueQuery[p.In.Name] { + return baseURL, body, nil } - httpRequest, err := http.NewRequest(aRoute.Path.Method, URL, body) - if err != nil { - return nil, jsonrpc.NewInvalidRequest(err.Error(), nil) + uniqueQuery[p.In.Name] = true + if value == nil || value == "" { + return baseURL, body, nil } - r.addAuthTokenIfPresent(ctx, httpRequest) - httpRequest.RequestURI = httpRequest.URL.RequestURI() - if URI != aRoute.URI() { - if matchedRoute, _ := r.match(component.Method, URI, httpRequest); matchedRoute != nil { - aRoute = matchedRoute + if slice, ok := value.([]interface{}); ok { + var items []string + for _, item := range slice { + if f, ok := item.(float64); ok { + items = append(items, fmt.Sprintf("%v", int64(f))) + } else { + items = append(items, fmt.Sprintf("%v", item)) + } } + values.Add(p.In.Name, strings.Join(items, ",")) + } else { + values.Add(p.In.Name, fmt.Sprintf("%v", value)) } - aRoute.Handle(responseWriter, httpRequest) // route the request to the actual handler - var result = schema.CallToolResult{} - mimeType := "application/json" - item := schema.CallToolResultContentElem{ - MimeType: mimeType, - Type: "text", // use data for some clients - Text: responseWriter.Body.String(), + case state.KindRequestBody: + if text, ok := value.(string); ok { + body = strings.NewReader(text) + } else { + data, err := json.Marshal(value) + if err != nil { + return baseURL, body, jsonrpc.NewInvalidParamsError("failed to marshal request body", nil) + } + body = strings.NewReader(string(data)) } - result.Content = append(result.Content, item) - return &result, nil } - 
return handler + return baseURL, body, nil +} + +// newToolHTTPRequest constructs an HTTP request for routed tool invocation. +func (r *Router) newToolHTTPRequest(method, URL string, body io.Reader) (*http.Request, *jsonrpc.Error) { + httpRequest, err := http.NewRequest(method, URL, body) + if err != nil { + return nil, jsonrpc.NewInvalidRequest(err.Error(), nil) + } + return httpRequest, nil +} + +// buildToolCallResult composes a CallToolResult with text content and structured error info if status is not OK. +func (r *Router) buildToolCallResult(responseWriter *proxy.Writer, URL, method string) *schema.CallToolResult { + var result = &schema.CallToolResult{} + mimeType := responseWriter.HeaderMap.Get("Content-Type") + if mimeType == "" { + mimeType = "application/json" + } + data := responseWriter.Body.Bytes() + result.Content = append(result.Content, schema.CallToolResultContentElem( + schema.TextContent{ + Type: "text", + Text: string(data), + }, + )) + _ = json.Unmarshal(data, &result.StructuredContent) + if responseWriter.Code >= http.StatusBadRequest { + isErr := true + result.IsError = &isErr + result.StructuredContent = map[string]interface{}{ + "status": responseWriter.Code, + "error": true, + "message": responseWriter.Body.String(), + "headers": responseWriter.HeaderMap, + "uri": URL, + "method": method, + } + } + return result } func (r *Router) matchToolCallComponentURI(aRoute *Route, component *repository.Component, params schema.CallToolRequestParams) string { @@ -186,10 +338,42 @@ func (r *Router) addAuthTokenIfPresent(ctx context.Context, httpRequest *http.Re } } +const defaultMCPProtectedResource = "https://datly.viantinc.com" + +func (r *Router) mcpUnauthorizedError() *jsonrpc.Error { + if r == nil || r.config == nil || r.config.MCP == nil { + return jsonrpc.NewError(schema.Unauthorized, "Unauthorized", nil) + } + issuerURL := strings.TrimSpace(r.config.MCP.IssuerURL) + if issuerURL == "" { + return jsonrpc.NewError(schema.Unauthorized, 
"Unauthorized", nil) + } + return jsonrpc.NewError(schema.Unauthorized, "Unauthorized", &authorization.Authorization{ + RequiredScopes: []string{}, + UseIdToken: true, + ProtectedResourceMetadata: &oauthmeta.ProtectedResourceMetadata{ + Resource: defaultMCPProtectedResource, + AuthorizationServers: []string{issuerURL}, + }, + }) +} + func (r *Router) buildToolInputType(components *repository.Component) reflect.Type { var inputFields []reflect.StructField + var uniqueFieldName = make(map[string]bool) var uniqueQuery = make(map[string]bool) var uniquePath = make(map[string]bool) + appendField := func(name string, fieldType reflect.Type, tag reflect.StructTag) { + if name == "" || fieldType == nil { + return + } + if uniqueFieldName[name] { + return + } + uniqueFieldName[name] = true + inputFields = append(inputFields, reflect.StructField{Name: name, Type: fieldType, Tag: tag}) + } + // Include component input parameters for _, parameter := range components.Input.Type.Parameters { name := strings.Title(parameter.Name) switch parameter.In.Kind { @@ -198,20 +382,63 @@ func (r *Router) buildToolInputType(components *repository.Component) reflect.Ty continue } uniquePath[parameter.In.Name] = true - inputFields = append(inputFields, reflect.StructField{Name: name, Type: parameter.Schema.Type()}) + // If parameter is a slice, make it optional in schema via `omitempty` and optional:"true". + var tag reflect.StructTag + if parameter.Schema != nil && parameter.Schema.Type().Kind() == reflect.Slice { + tag = `json:",omitempty" optional:"true"` + } + appendField(name, parameter.Schema.Type(), tag) case state.KindQuery, state.KindForm: if uniqueQuery[parameter.In.Name] { continue } uniqueQuery[parameter.In.Name] = true + // Repeated (slice) params are optional regardless of "required" tag. + // Otherwise, respect explicit required; default to optional. 
tag := reflect.StructTag(parameter.Tag) - if !strings.Contains(parameter.Tag, "required") { + if parameter.Schema != nil && parameter.Schema.Type().Kind() == reflect.Slice { + tag = `json:",omitempty" optional:"true"` + } else if !strings.Contains(parameter.Tag, "required") { tag = `json:",omitempty"` } - inputFields = append(inputFields, reflect.StructField{Name: name, Type: parameter.Schema.Type(), Tag: tag}) + appendField(name, parameter.Schema.Type(), tag) case state.KindRequestBody: - inputFields = append(inputFields, reflect.StructField{Name: name, Type: parameter.Schema.Type()}) + // If body is a slice, mark optional in schema. + var tag reflect.StructTag + if parameter.Schema != nil && parameter.Schema.Type().Kind() == reflect.Slice { + tag = `json:",omitempty" optional:"true"` + } + appendField(name, parameter.Schema.Type(), tag) + } + } + + // Include selector (limit/offset/fields/page) for read components when available + if components.View != nil && components.View.Selector != nil { + if p := components.View.Selector.LimitParameter; p != nil && p.In != nil && p.In.Name != "" { + if !uniqueQuery[p.In.Name] { // avoid duplicates + uniqueQuery[p.In.Name] = true + appendField(strings.Title(p.Name), p.Schema.Type(), `json:",omitempty"`) + } + } + if p := components.View.Selector.OffsetParameter; p != nil && p.In != nil && p.In.Name != "" { + if !uniqueQuery[p.In.Name] { + uniqueQuery[p.In.Name] = true + appendField(strings.Title(p.Name), p.Schema.Type(), `json:",omitempty"`) + } + } + if p := components.View.Selector.FieldsParameter; p != nil && p.In != nil && p.In.Name != "" { + if !uniqueQuery[p.In.Name] { + uniqueQuery[p.In.Name] = true + // Fields is a []string – ensure optional in schema + appendField(strings.Title(p.Name), p.Schema.Type(), `json:",omitempty" optional:"true"`) + } + } + if p := components.View.Selector.PageParameter; p != nil && p.In != nil && p.In.Name != "" { + if !uniqueQuery[p.In.Name] { + uniqueQuery[p.In.Name] = true + 
appendField(strings.Title(p.Name), p.Schema.Type(), `json:",omitempty"`) + } } } @@ -229,6 +456,23 @@ func (r *Router) buildTemplateResourceIntegration(item *dpath.Item, aPath *dpath parameterNames = append(parameterNames, parameter.In.Name) } } + // Also expose view selector pagination controls in URI template if present + if provider != nil { + if comp, err := provider.Component(context.Background()); err == nil && comp.View != nil && comp.View.Selector != nil { + if p := comp.View.Selector.LimitParameter; p != nil && p.In != nil && p.In.Name != "" { + parameterNames = append(parameterNames, p.In.Name) + } + if p := comp.View.Selector.OffsetParameter; p != nil && p.In != nil && p.In.Name != "" { + parameterNames = append(parameterNames, p.In.Name) + } + if p := comp.View.Selector.FieldsParameter; p != nil && p.In != nil && p.In.Name != "" { + parameterNames = append(parameterNames, p.In.Name) + } + if p := comp.View.Selector.PageParameter; p != nil && p.In != nil && p.In.Name != "" { + parameterNames = append(parameterNames, p.In.Name) + } + } + } canBuildTemplateResource := len(parameterNames) > 0 || strings.Contains(aPath.URI, "{") if !canBuildTemplateResource { return nil @@ -261,9 +505,9 @@ func (r *Router) buildTemplateResourceIntegration(item *dpath.Item, aPath *dpath func (r *Router) reactMcpResourceHandler(mcpResourceTemplate schema.ResourceTemplate, aRoute *Route, provider *repository.Provider) func(ctx context.Context, request *schema.ReadResourceRequest) (*schema.ReadResourceResult, *jsonrpc.Error) { handler := func(ctx context.Context, request *schema.ReadResourceRequest) (*schema.ReadResourceResult, *jsonrpc.Error) { - result, err := r.handleMcpRead(ctx, &request.Params, &mcpResourceTemplate, aRoute, provider) - if err != nil { - return nil, jsonrpc.NewInternalError(err.Error(), nil) + result, rpcErr := r.handleMcpRead(ctx, &request.Params, &mcpResourceTemplate, aRoute, provider) + if rpcErr != nil { + return nil, rpcErr } if len(result) == 0 { 
return &schema.ReadResourceResult{Contents: []schema.ReadResourceResultContentsElem{}}, nil @@ -325,12 +569,12 @@ func (r *Router) hasMcpResource(URI string) bool { return false } -func (r *Router) handleMcpRead(ctx context.Context, params *schema.ReadResourceRequestParams, template *schema.ResourceTemplate, aRoute *Route, provider *repository.Provider) ([]schema.ReadResourceResultContentsElem, error) { +func (r *Router) handleMcpRead(ctx context.Context, params *schema.ReadResourceRequestParams, template *schema.ResourceTemplate, aRoute *Route, provider *repository.Provider) ([]schema.ReadResourceResultContentsElem, *jsonrpc.Error) { URI := furl.Path(params.Uri) URL := fmt.Sprintf("http://localhost/%v", URI) // fallback to a local URL for now, this should be replaced with the actual service URL component, err := provider.Component(ctx) // ensure the provider is initialized if err != nil { - return nil, fmt.Errorf("failed to get component from provider: %w", err) + return nil, jsonrpc.NewInternalError(fmt.Errorf("failed to get component from provider: %w", err).Error(), nil) } byLoc := make(map[string]*state.Parameter) for _, param := range component.View.GetResource().Parameters { @@ -344,6 +588,9 @@ func (r *Router) handleMcpRead(ctx context.Context, params *schema.ReadResourceR } r.addAuthTokenIfPresent(ctx, httpRequest) aRoute.Handle(responseWriter, httpRequest) // route the request to the actual handler + if responseWriter.Code == http.StatusUnauthorized { + return nil, r.mcpUnauthorizedError() + } var result []schema.ReadResourceResultContentsElem mimeType := "" if template.MimeType != nil { diff --git a/gateway/route.go b/gateway/route.go index 24e682d3d..9f5cf6eed 100644 --- a/gateway/route.go +++ b/gateway/route.go @@ -2,7 +2,11 @@ package gateway import ( "context" - "github.com/goccy/go-json" + "encoding/json" + "net/http" + "strings" + "time" + "github.com/viant/afs/url" "github.com/viant/datly/gateway/router" "github.com/viant/datly/repository" @@ 
-11,8 +15,8 @@ import ( "github.com/viant/datly/repository/path" vcontext "github.com/viant/datly/view/context" "github.com/viant/xdatly/handler/exec" - "net/http" - "strings" + + dlogger "github.com/viant/datly/logger" ) const ( @@ -33,6 +37,9 @@ type ( Handler func(ctx context.Context, response http.ResponseWriter, req *http.Request) `json:"-"` logging.Config Version string + + // Counter is an optional per-route metrics counter + Counter dlogger.Counter `json:"-"` } ) @@ -43,7 +50,38 @@ func (r *Route) Handle(res http.ResponseWriter, req *http.Request) int { ctx := context.Background() execContext := exec.NewContext(req.Method, req.RequestURI, req.Header, r.Version) ctx = vcontext.WithValue(ctx, exec.ContextKey, execContext) + var onDone func(time.Time, ...interface{}) int64 = nil + var start time.Time + if r.Counter != nil { + start = time.Now() + onDone = r.Counter.Begin(start) + } r.Handler(ctx, res, req) + + // finalize metrics + if onDone != nil { + end := time.Now() + onDone(end) + // Determine final status code + statusCode := execContext.StatusCode + if statusCode == 0 { + statusCode = http.StatusOK + } + // Increment error/success buckets + if statusCode >= 200 && statusCode < 300 { + r.Counter.IncrementValue("Success") + r.Counter.IncrementValue("status:2xx") + } else if statusCode >= 400 && statusCode < 500 { + r.Counter.IncrementValue("Error") + r.Counter.IncrementValue("status:4xx") + } else if statusCode >= 500 { + r.Counter.IncrementValue("Error") + r.Counter.IncrementValue("status:5xx") + } else { + // Treat other codes as success by default + r.Counter.IncrementValue("Success") + } + } if execContext.StatusCode == 0 { execContext.StatusCode = http.StatusOK } @@ -66,7 +104,7 @@ func (r *Router) NewRouteHandler(handler *router.Handler) *Route { if !strings.HasPrefix(URI, "/") { URI = "/" + URI } - return &Route{ + route := &Route{ Path: &handler.Path.Path, MCP: &handler.Path.ModelContextProtocol, Meta: &handler.Path.Meta, @@ -75,6 +113,9 @@ func 
(r *Router) NewRouteHandler(handler *router.Handler) *Route { Config: r.config.Logging, Version: r.config.Version, } + // Pre-register and attach per-route counter if metrics are enabled + route.Counter = r.ensureRouteCounter(context.Background(), handler.Provider) + return route } func (r *Route) URI() string { diff --git a/gateway/route_metrics.go b/gateway/route_metrics.go new file mode 100644 index 000000000..213f6c4a3 --- /dev/null +++ b/gateway/route_metrics.go @@ -0,0 +1,73 @@ +package gateway + +import ( + "context" + "path" + "strings" + "time" + + dlogger "github.com/viant/datly/logger" + "github.com/viant/datly/repository" + gprovider "github.com/viant/gmetric/provider" +) + +// ensureRouteCounter pre-registers a per-route counter and returns a logger-compatible adapter. +func (r *Router) ensureRouteCounter(ctx context.Context, prov *repository.Provider) dlogger.Counter { + if r.metrics == nil || prov == nil { + return nil + } + component, err := prov.Component(ctx) + if err != nil || component == nil || component.View == nil { + return nil + } + + v := component.View + + // Derive a stable package from resource URL similar to view.discoverPackage + pkg := "datly" + if res := v.GetResource(); res != nil { + src := res.SourceURL + // Extract the dir and find the segment after "/routes/" + parent, _ := path.Split(src) + if idx := strings.Index(parent, "/routes/"); idx != -1 { + pkg = strings.Trim(parent[idx+len("/routes/"):], "/") + } + } + + // Build a metric operation name aligned with view metrics namespace, but scoped to component URI (.request) + method := component.Path.Method + normURI := normalizeURI(component.URI) + name := strings.Trim(normURI, "/") + ".request" + name = strings.ReplaceAll(name, "/", ".") + metricName := pkg + "." 
+ name + if method != "" && !strings.EqualFold(method, "GET") { + metricName = method + ":" + metricName + } + metricName = strings.ReplaceAll(metricName, "/", ".") + + cnt := r.metrics.LookupOperation(metricName) + if cnt == nil { + // Title: human-friendly + title := v.Name + " request" + cnt = r.metrics.MultiOperationCounter(pkg, metricName, title, time.Millisecond, time.Minute, 2, gprovider.NewBasic()) + } + return dlogger.NewCounter(cnt) +} + +// normalizeURI replaces path parameters like {id} with a constant token to limit cardinality. +func normalizeURI(uri string) string { + res := uri + for { + i := strings.Index(res, "{") + if i == -1 { + break + } + j := strings.Index(res[i:], "}") + if j == -1 { + break + } + j = i + j + 1 + res = res[:i] + "T" + res[j:] + } + return res +} diff --git a/gateway/router.go b/gateway/router.go index 5da821bd0..a63bf8a98 100644 --- a/gateway/router.go +++ b/gateway/router.go @@ -18,11 +18,14 @@ import ( "github.com/viant/datly/repository/path" "github.com/viant/datly/service/operator" "github.com/viant/datly/service/session" + "github.com/viant/datly/shared/logging" "github.com/viant/datly/view" vcontext "github.com/viant/datly/view/context" + "github.com/viant/datly/view/state/kind/locator" "github.com/viant/gmetric" serverproto "github.com/viant/mcp-protocol/server" "github.com/viant/xdatly/handler/async" + "github.com/viant/xdatly/handler/logger" hstate "github.com/viant/xdatly/handler/state" "net/http" @@ -38,6 +41,7 @@ type ( repository *repository.Service operator *operator.Service config *Config + logger logger.Logger OpenAPIInfo openapi3.Info metrics *gmetric.Service statusHandler http.Handler @@ -78,6 +82,7 @@ func NewRouter(ctx context.Context, components *repository.Service, config *Conf operator: operator.New(), apiKeyMatcher: newApiKeyMatcher(config.APIKeys), mcpRegistry: mcpRegistry, + logger: logging.New(logging.INFO, nil), } return r, r.init(ctx) } @@ -154,8 +159,10 @@ func (r *Router) HandleJob(ctx 
context.Context, aJob *async.Job) error { request := &http.Request{Method: aJob.Method, URL: URL, RequestURI: aPath.URI} unmarshal := aComponent.UnmarshalFunc(request) locatorOptions := append(aComponent.LocatorOptions(request, hstate.NewForm(), unmarshal)) + locatorOptions = append(locatorOptions, locator.WithLogger(r.logger)) aSession := session.New(aComponent.View, session.WithAuth(r.repository.Auth()), + session.WithLogger(r.logger), session.WithComponent(aComponent), session.WithLocatorOptions(locatorOptions...), session.WithOperate(r.operator.Operate)) @@ -342,7 +349,7 @@ func (r *Router) newMatcher(ctx context.Context) (*matcher.Matcher, []*contract. } r.EnsureCors(aPath) - aRoute := r.NewRouteHandler(router.New(aPath, provider, r.repository.Registry(), r.repository.Auth(), r.config.Version, r.config.Logging)) + aRoute := r.NewRouteHandler(router.New(aPath, provider, r.repository.Registry(), r.repository.Auth(), r.config.Version, r.config.Logging, r.logger)) routes = append(routes, aRoute) if aPath.Cors != nil { optionsPaths[aPath.URI] = append(optionsPaths[aPath.URI], aPath) diff --git a/gateway/router/handler.go b/gateway/router/handler.go index b45917b65..dbc40f3da 100644 --- a/gateway/router/handler.go +++ b/gateway/router/handler.go @@ -27,7 +27,9 @@ import ( "github.com/viant/datly/view" vcontext "github.com/viant/datly/view/context" "github.com/viant/datly/view/state" + "github.com/viant/datly/view/state/kind/locator" "github.com/viant/xdatly/handler/exec" + "github.com/viant/xdatly/handler/logger" "github.com/viant/xdatly/handler/response" hstate "github.com/viant/xdatly/handler/state" "io" @@ -54,6 +56,7 @@ type ( registry *repository.Registry auth *auth.Service logging logging.Config + logger logger.Logger } ) @@ -87,7 +90,7 @@ func (r *Handler) AuthorizeRequest(request *http.Request, aPath *path.Path) erro return nil } -func New(aPath *path.Path, provider *repository.Provider, registry *repository.Registry, authService *auth.Service, version 
string, config logging.Config) *Handler { +func New(aPath *path.Path, provider *repository.Provider, registry *repository.Registry, authService *auth.Service, version string, config logging.Config, logger logger.Logger) *Handler { ret := &Handler{ Path: aPath, Provider: provider, @@ -96,6 +99,7 @@ func New(aPath *path.Path, provider *repository.Provider, registry *repository.R auth: authService, Version: version, logging: config, + logger: logger, } return ret } @@ -184,6 +188,10 @@ func (r *Handler) Handle(ctx context.Context, writer http.ResponseWriter, reques http.Error(writer, err.Error(), http.StatusInternalServerError) return } + if aComponent == nil { + http.Error(writer, "component not available", http.StatusServiceUnavailable) + return + } aResponse, err := r.safelyHandleComponent(ctx, request, aComponent) if err != nil { r.writeErrorResponse(ctx, writer, aComponent, err, http.StatusBadRequest) @@ -233,6 +241,20 @@ func (r *Handler) writeErrorResponse(ctx context.Context, w http.ResponseWriter, execCtx.SetError(err) } responseStatus := r.responseStatusError(message, anObjectErr) + if aComponent == nil || aComponent.Output.Type.Parameters == nil { + errAsBytes, marshalErr := goJson.Marshal(responseStatus) + if marshalErr != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not parse error message")) + return + } + if execCtx != nil { + execCtx.StatusCode = statusCode + } + w.WriteHeader(statusCode) + w.Write(errAsBytes) + return + } statusParameter := aComponent.Output.Type.Parameters.LookupByLocation(state.KindOutput, "status") if statusParameter == nil { errAsBytes, marshalErr := goJson.Marshal(responseStatus) @@ -257,11 +279,8 @@ func (r *Handler) writeErrorResponse(ctx context.Context, w http.ResponseWriter, http.Error(w, err.Error(), http.StatusInternalServerError) return } - if aComponent.Content.Marshaller.JSON.CanMarshal() { - data, err = aComponent.Marshaller.JSON.Codec.Marshal(aResponse.State()) - } else { - data, err = 
aComponent.Marshaller.JSON.JsonMarshaller.Marshal(aResponse.State()) - } + mf := aComponent.MarshalFunc() + data, err = mf(aResponse.State()) if err != nil { w.Write(data) if execCtx != nil { @@ -390,11 +409,14 @@ func (r *Handler) handleComponent(ctx context.Context, request *http.Request, aC anOperator := operator.New() unmarshal := aComponent.UnmarshalFunc(request) locatorOptions := append(aComponent.LocatorOptions(request, hstate.NewForm(), unmarshal)) + locatorOptions = append(locatorOptions, locator.WithLogger(r.logger)) aSession := session.New(aComponent.View, session.WithAuth(r.auth), + session.WithLogger(r.logger), session.WithComponent(aComponent), session.WithLocatorOptions(locatorOptions...), session.WithRegistry(r.registry), + session.WithOperate(anOperator.Operate)) err := aSession.InitKinds(state.KindComponent, state.KindHeader, state.KindRequestBody, state.KindForm, state.KindQuery) if err != nil { @@ -455,8 +477,10 @@ func (r *Handler) handleComponent(ctx context.Context, request *http.Request, aC options.Append(response.WithHeader("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.xlsx"`, aComponent.Output.GetTitle()))) } } + // Use component-level marshaller with request-scoped options filters := aComponent.Exclusion(aSession.State()) - data, err := aComponent.Content.Marshal(format, aComponent.Output.Field(), output, filters) + mf := aComponent.MarshalFunc(repository.WithRequest(request), repository.WithFormat(format), repository.WithFilters(filters)) + data, err := mf(output) if err != nil { return nil, response.NewError(500, fmt.Sprintf("failed to marshal response: %v", err), response.WithError(err)) } @@ -494,13 +518,9 @@ func (r *Handler) marshalComponentOutput(output interface{}, aComponent *reposit case []byte: return response.NewBuffered(response.WithBytes(actual)), nil default: - var data []byte - var err error - if aComponent.Content.Marshaller.JSON.CanMarshal() { - data, err = 
aComponent.Content.Marshaller.JSON.Codec.Marshal(output) - } else { - data, err = aComponent.Content.Marshaller.JSON.JsonMarshaller.Marshal(output) - } + // Default to JSON marshalling using component-level marshaller + mf := aComponent.MarshalFunc() + data, err := mf(output) if err != nil { return nil, response.NewError(http.StatusInternalServerError, err.Error(), response.WithError(err)) } diff --git a/gateway/router/marshal/json/benchmark_groups_test.go b/gateway/router/marshal/json/benchmark_groups_test.go new file mode 100644 index 000000000..70a8b4e6a --- /dev/null +++ b/gateway/router/marshal/json/benchmark_groups_test.go @@ -0,0 +1,115 @@ +package json + +import ( + "testing" + "time" + + "github.com/viant/datly/gateway/router/marshal/config" + "github.com/viant/tagly/format/text" +) + +type benchBasic struct { + ID int + Name string + Score float64 + On bool +} + +type benchAdvancedChild struct { + Code string + Value int +} + +type benchAdvanced struct { + ID int + CreatedAt time.Time + Tags []string + Meta map[string]string + Items []*benchAdvancedChild + Any interface{} +} + +func benchmarkMarshaller() *Marshaller { + return New(&config.IOConfig{ + CaseFormat: text.CaseFormatLowerCamel, + TimeLayout: time.RFC3339, + }) +} + +func benchmarkBasicData() []benchBasic { + return []benchBasic{ + {ID: 1, Name: "a", Score: 1.5, On: true}, + {ID: 2, Name: "b", Score: 2.5, On: false}, + {ID: 3, Name: "c", Score: 3.5, On: true}, + } +} + +func benchmarkAdvancedData() []benchAdvanced { + now := time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC) + return []benchAdvanced{ + { + ID: 10, + CreatedAt: now, + Tags: []string{"x", "y", "z"}, + Meta: map[string]string{"count": "3", "ok": "true"}, + Items: []*benchAdvancedChild{{Code: "a", Value: 1}, {Code: "b", Value: 2}}, + Any: map[string]interface{}{"kind": "demo", "n": 1}, + }, + } +} + +func BenchmarkMarshaller_Marshal_Basic(b *testing.B) { + m := benchmarkMarshaller() + data := benchmarkBasicData() + b.ReportAllocs() + for i 
:= 0; i < b.N; i++ { + _, err := m.Marshal(data) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkMarshaller_Unmarshal_Basic(b *testing.B) { + m := benchmarkMarshaller() + seed := benchmarkBasicData() + encoded, err := m.Marshal(seed) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var out []benchBasic + if err = m.Unmarshal(encoded, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkMarshaller_Marshal_Advanced(b *testing.B) { + m := benchmarkMarshaller() + data := benchmarkAdvancedData() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := m.Marshal(data) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkMarshaller_Unmarshal_Advanced(b *testing.B) { + m := benchmarkMarshaller() + seed := benchmarkAdvancedData() + encoded, err := m.Marshal(seed) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var out []benchAdvanced + if err = m.Unmarshal(encoded, &out); err != nil { + b.Fatal(err) + } + } +} diff --git a/gateway/router/marshal/json/cache.go b/gateway/router/marshal/json/cache.go index cd79eba61..c05189bea 100644 --- a/gateway/router/marshal/json/cache.go +++ b/gateway/router/marshal/json/cache.go @@ -3,13 +3,14 @@ package json import ( "bytes" "fmt" + "reflect" + "sync" + "github.com/viant/datly/gateway/router/marshal/config" "github.com/viant/tagly/format" "github.com/viant/tagly/format/text" "github.com/viant/xreflect" "github.com/viant/xunsafe" - "reflect" - "sync" ) var buffersPool *buffers @@ -99,30 +100,39 @@ func (m *marshallersCache) loadMarshaller(rType reflect.Type, config *config.IOC return marshaller, nil } -func (c *pathCache) loadOrGetMarshaller(rType reflect.Type, config *config.IOConfig, path string, outputPath string, tag *format.Tag, options ...interface{}) (marshaler, error) { - value, ok := c.cache.Load(rType) +func (c *pathCache) loadOrGetMarshaller(rType reflect.Type, cfg *config.IOConfig, path, outPath string, tag 
*format.Tag, options ...interface{}) (marshaler, error) { + + placeholder := newDeferred() + value, ok := c.cache.LoadOrStore(rType, placeholder) if ok { return value.(marshaler), nil } - aMarshaler, err := c.getMarshaller(rType, config, path, outputPath, tag, options...) - + aMarshaller, err := c.getMarshaller(rType, cfg, path, outPath, tag, options...) if err != nil { + placeholder.fail(err) // unblock anyone holding the promise + c.cache.CompareAndDelete(rType, placeholder) // allow a clean retry later return nil, err } - c.storeMarshaler(rType, aMarshaler) - return aMarshaler, nil + placeholder.setTarget(aMarshaller) // resolve success + return aMarshaller, nil } func (c *pathCache) getMarshaller(rType reflect.Type, config *config.IOConfig, path string, outputPath string, tag *format.Tag, options ...interface{}) (marshaler, error) { + if rType == nil { + return nil, fmt.Errorf("nil reflect.Type for path %q", path) + } if tag == nil { tag = &format.Tag{} } aConfig := c.parseConfig(options) - if (aConfig == nil || !aConfig.ignoreCustomUnmarshaller) && rType.Implements(unmarshallerIntoType) { - return newCustomUnmarshaller(rType, config, path, outputPath, tag, c.parent) + // Keep UnmarshalerInto precedence for non-structs; structs handled below to honor gojay first. + if rType.Kind() != reflect.Struct { + if (aConfig == nil || !aConfig.IgnoreCustomUnmarshaller) && rType.Implements(unmarshallerIntoType) { + return newCustomUnmarshaller(rType, config, path, outputPath, tag, c.parent) + } } switch rType { @@ -212,12 +222,36 @@ func (c *pathCache) getMarshaller(rType reflect.Type, config *config.IOConfig, p return newTimeMarshaller(tag, config), nil } - marshaller, err := newStructMarshaller(config, rType, path, outputPath, tag, c.parent) + // Decide if type uses gojay; build base without init to handle self-references safely. 
+ hasMarshal := (aConfig == nil || !aConfig.IgnoreCustomMarshaller) && (rType.Implements(marshalerJSONObjectType) || reflect.PtrTo(rType).Implements(marshalerJSONObjectType)) + hasUnmarshal := (aConfig == nil || !aConfig.IgnoreCustomMarshaller) && (rType.Implements(unmarshalerJSONObjectType) || reflect.PtrTo(rType).Implements(unmarshalerJSONObjectType)) + + base, err := newStructMarshaller(config, rType, path, outputPath, tag, c.parent) if err != nil { return nil, err } - return marshaller, nil + if hasMarshal || hasUnmarshal { + // Wrap base with gojay; placeholder at loadOrGet level already breaks cycles. + wrapper := newGojayObjectMarshaller(getXType(rType), getXType(reflect.PtrTo(rType)), base, hasMarshal, hasUnmarshal) + if err := base.init(); err != nil { + return nil, err + } + return wrapper, nil + } + + // No gojay: just init base and return (placeholder already in place). + if err := base.init(); err != nil { + return nil, err + } + + // Allow custom unmarshaller on structs if defined and not ignored (only if no gojay used). + if (aConfig == nil || !aConfig.IgnoreCustomUnmarshaller) && rType.Implements(unmarshallerIntoType) { + // Avoid self-referential lookup through placeholder for the same type. 
+ return newCustomUnmarshallerWithMarshaller(rType, config, path, outputPath, tag, c.parent, base), nil + } + + return base, nil case reflect.Interface: marshaller, err := newInterfaceMarshaller(rType, config, path, outputPath, tag, c.parent) diff --git a/gateway/router/marshal/json/coverage_additional_test.go b/gateway/router/marshal/json/coverage_additional_test.go new file mode 100644 index 000000000..ddcbc6aef --- /dev/null +++ b/gateway/router/marshal/json/coverage_additional_test.go @@ -0,0 +1,1532 @@ +package json + +import ( + "bytes" + stdjson "encoding/json" + "errors" + "reflect" + "sync" + "testing" + "time" + "unsafe" + + "github.com/francoispqt/gojay" + "github.com/stretchr/testify/require" + "github.com/viant/datly/gateway/router/marshal/config" + "github.com/viant/tagly/format" + "github.com/viant/tagly/format/text" + "github.com/viant/xunsafe" +) + +type fallbackMarshaller struct { + marshalCalled bool + unmarshalCalled bool +} + +type errMarshaller struct{} + +func (e *errMarshaller) MarshallObject(ptr unsafe.Pointer, session *MarshallSession) error { + return errors.New("marshal err") +} +func (e *errMarshaller) UnmarshallObject(pointer unsafe.Pointer, decoder *gojay.Decoder, auxiliaryDecoder *gojay.Decoder, session *UnmarshalSession) error { + return errors.New("unmarshal err") +} + +func (f *fallbackMarshaller) MarshallObject(ptr unsafe.Pointer, session *MarshallSession) error { + f.marshalCalled = true + session.WriteString(`{"fallback":true}`) + return nil +} + +func (f *fallbackMarshaller) UnmarshallObject(pointer unsafe.Pointer, decoder *gojay.Decoder, auxiliaryDecoder *gojay.Decoder, session *UnmarshalSession) error { + f.unmarshalCalled = true + return nil +} + +type gjOnlyPtr struct { + V int +} + +func (g *gjOnlyPtr) MarshalJSONObject(enc *gojay.Encoder) { + enc.IntKey("V", g.V) +} + +func (g *gjOnlyPtr) IsNil() bool { return g == nil } + +func (g *gjOnlyPtr) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { + if key == "V" { 
+ return dec.Int(&g.V) + } + return nil +} + +func (g *gjOnlyPtr) NKeys() int { return 0 } + +type customSum int +type customStruct int +type customStructHolder struct { + V int +} +type gojayBadInit struct { + C chan int +} + +type withM interface{ M() } +type withMImpl struct{} + +func (withMImpl) M() {} + +func (c *customSum) UnmarshalJSONWithOptions(dst interface{}, decoder *gojay.Decoder, options ...interface{}) error { + var vals []int + if err := decoder.SliceInt(&vals); err != nil { + return err + } + sum := 0 + for _, v := range vals { + sum += v + } + *dst.(**customSum) = (*customSum)(&sum) + return nil +} + +func (customStruct) UnmarshalJSONWithOptions(dst interface{}, decoder *gojay.Decoder, options ...interface{}) error { + var v int + if err := decoder.Int(&v); err != nil { + return err + } + p := dst.(*customStruct) + *p = customStruct(v) + return nil +} + +func (c customStructHolder) UnmarshalJSONWithOptions(dst interface{}, decoder *gojay.Decoder, options ...interface{}) error { + var v int + if err := decoder.Int(&v); err != nil { + return err + } + c.V = v + p := dst.(*customStructHolder) + *p = c + return nil +} + +func (g gojayBadInit) MarshalJSONObject(enc *gojay.Encoder) {} +func (g gojayBadInit) IsNil() bool { return false } + +func TestCoverage_OptionsAndTags(t *testing.T) { + opts := Options{&Tag{FieldName: "x"}, &format.Tag{Name: "y"}} + require.Equal(t, "x", opts.Tag().FieldName) + require.Equal(t, "y", opts.FormatTag().Name) + + parsed := Parse("name,omitempty") + require.Equal(t, "name", parsed.FieldName) + require.True(t, parsed.OmitEmpty) + + transient := Parse("-") + require.True(t, transient.Transient) + + xTag := ParseXTag("", "inline") + require.True(t, xTag.Inline) +} + +func TestCoverage_DefaultTagAndParseValue(t *testing.T) { + type sample struct { + A *int `default:"value=7,nullable=false,required=true"` + B time.Time `default:"value=2024-01-01T00:00:00Z,format=2006-01-02T15:04:05Z07:00"` + C *time.Time 
`default:"value=2024-01-01T00:00:00Z,format=2006-01-02T15:04:05Z07:00"` + } + rType := reflect.TypeOf(sample{}) + + aTag, err := NewDefaultTag(rType.Field(0)) + require.NoError(t, err) + require.True(t, aTag.IsRequired()) + require.False(t, aTag.IsNullable()) + + bTag, err := NewDefaultTag(rType.Field(1)) + require.NoError(t, err) + require.NotNil(t, bTag._value) + + cTag, err := NewDefaultTag(rType.Field(2)) + require.NoError(t, err) + require.NotNil(t, cTag._value) + + _, err = parseValue(reflect.TypeOf(time.Time{}), "bad-time", time.RFC3339) + require.Error(t, err) +} + +func TestCoverage_BytesSliceUnmarshal(t *testing.T) { + var b []byte + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`[1,2,3]`))) + defer dec.Release() + require.NoError(t, dec.Array(&BytesSlice{b: &b})) + require.Equal(t, []byte{1, 2, 3}, b) + + var bPtr *[]byte + dec2 := gojay.BorrowDecoder(bytes.NewReader([]byte(`[4,5]`))) + defer dec2.Release() + require.NoError(t, dec2.Array(&BytesPtrSlice{b: &bPtr})) + require.Equal(t, []byte{4, 5}, *bPtr) +} + +func TestCoverage_ErrorJoin(t *testing.T) { + err := NewError("a", errors.New("x")) + require.Contains(t, err.Error(), "failed to unmarshal a") + + nested := NewError("obj", NewError("field", errors.New("boom"))) + require.Equal(t, "obj.field", nested.Path) + + nestedArr := NewError("arr", NewError("[1]", errors.New("boom"))) + require.Equal(t, "arr[1]", nestedArr.Path) +} + +func TestCoverage_UnsignedAndPointers_MarshalUnmarshal(t *testing.T) { + type payload struct { + U uint + U8 uint8 + U16 uint16 + U32 uint32 + U64 uint64 + PU *uint + P8 *uint8 + P16 *uint16 + P32 *uint32 + P64 *uint64 + } + m := New(&config.IOConfig{}) + + u := uint(10) + u8 := uint8(11) + u16 := uint16(12) + u32 := uint32(13) + u64 := uint64(14) + in := payload{U: 1, U8: 2, U16: 3, U32: 4, U64: 5, PU: &u, P8: &u8, P16: &u16, P32: &u32, P64: &u64} + + data, err := m.Marshal(in) + require.NoError(t, err) + + var out payload + require.NoError(t, m.Unmarshal(data, &out)) + 
require.Equal(t, in.U, out.U) + require.Equal(t, in.U8, out.U8) + require.Equal(t, in.U16, out.U16) + require.Equal(t, in.U32, out.U32) + require.Equal(t, in.U64, out.U64) + require.NotNil(t, out.PU) + require.NotNil(t, out.P8) + require.NotNil(t, out.P16) + require.NotNil(t, out.P32) + require.NotNil(t, out.P64) +} + +func TestCoverage_ArrayAndMapEdges(t *testing.T) { + m := New(&config.IOConfig{CaseFormat: text.CaseFormatLowerUnderscore}) + + type boolArr struct { + Flags [3]bool + } + encoded, err := m.Marshal(boolArr{Flags: [3]bool{true, false, true}}) + require.NoError(t, err) + require.Contains(t, string(encoded), "[true,false,true]") + + var arrOut boolArr + err = m.Unmarshal([]byte(`{"Flags":[true,false,true]}`), &arrOut) + require.Error(t, err) // array unmarshal not supported + + type mapHolder struct { + M map[string]int + } + var mh mapHolder + require.NoError(t, m.Unmarshal([]byte(`{"M":{"a":1,"b":2}}`), &mh)) + require.Equal(t, 2, mh.M["b"]) + + type unsupported struct { + M map[string]bool + } + var bad unsupported + err = m.Unmarshal([]byte(`{"M":{"a":true}}`), &bad) + require.Error(t, err) +} + +func TestCoverage_InterfaceAndSliceInterface(t *testing.T) { + m := New(&config.IOConfig{}) + type obj struct { + Any interface{} + List []interface{} + } + var out obj + require.NoError(t, m.Unmarshal([]byte(`{"Any":{"k":1},"List":[1,"x",{"a":2}]}`), &out)) + require.Len(t, out.List, 1) // current behavior: appended as a single decoded interface payload + + encoded, err := m.Marshal(out) + require.NoError(t, err) + require.Contains(t, string(encoded), "\"List\"") +} + +func TestCoverage_CustomUnmarshallerAndGojayWrapper(t *testing.T) { + m := New(&config.IOConfig{}) + type holder struct { + Sum *customSum + G gjOnlyPtr + } + var out holder + require.NoError(t, m.Unmarshal([]byte(`{"Sum":[1,2,3],"G":{"V":7}}`), &out)) + require.NotNil(t, out.Sum) + require.Equal(t, 6, int(*out.Sum)) + require.Equal(t, 7, out.G.V) + + data, err := m.Marshal(out) + 
require.NoError(t, err) + require.Contains(t, string(data), `"V":7`) +} + +func TestCoverage_GojayWrapperFallbackAndDeferred(t *testing.T) { + rType := reflect.TypeOf(struct{ X int }{}) + fb := &fallbackMarshaller{} + wrapper := newGojayObjectMarshaller(getXType(rType), getXType(reflect.PtrTo(rType)), fb, true, true) + session := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + val := struct{ X int }{X: 1} + require.NoError(t, wrapper.MarshallObject(AsPtr(val, rType), session)) + require.True(t, fb.marshalCalled) + + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"X":1}`))) + defer dec.Release() + ptr := reflect.New(rType) + require.NoError(t, wrapper.UnmarshallObject(unsafe.Pointer(ptr.Pointer()), dec, nil, &UnmarshalSession{})) + require.True(t, fb.unmarshalCalled) + + d := newDeferred() + d.fail(errors.New("boom")) + require.Error(t, d.MarshallObject(nil, &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + d2 := newDeferred() + d2.setTarget(fb) + require.NoError(t, d2.MarshallObject(nil, &MarshallSession{Buffer: bytes.NewBuffer(nil)})) +} + +func TestCoverage_PathCacheHelpers(t *testing.T) { + pc := &pathCache{cache: sync.Map{}} + fb := &fallbackMarshaller{} + pc.storeMarshaler(reflect.TypeOf(1), fb) + got, ok := pc.loadMarshaller(reflect.TypeOf(1)) + require.True(t, ok) + require.NotNil(t, got) + + cfg := pc.parseConfig([]interface{}{&cacheConfig{IgnoreCustomMarshaller: true}}) + require.NotNil(t, cfg) + require.True(t, cfg.IgnoreCustomMarshaller) +} + +func TestCoverage_TimeAndRawMessageAndAsPtrMap(t *testing.T) { + cfg := &config.IOConfig{TimeLayout: "2006-01-02T15:04:05Z07:00"} + m := New(cfg) + now := time.Now().UTC().Truncate(time.Second) + type payload struct { + T time.Time + TP *time.Time + R stdjson.RawMessage + RP *stdjson.RawMessage + } + raw := stdjson.RawMessage(`{"a":1}`) + in := payload{T: now, TP: &now, R: raw, RP: &raw} + + data, err := m.Marshal(in) + require.NoError(t, err) + + var out payload + require.NoError(t, m.Unmarshal(data, 
&out)) + require.Equal(t, raw, out.R) + require.NotNil(t, out.RP) + + // map branch in AsPtr + mapped := map[string]int{"a": 1} + ptr := AsPtr(mapped, reflect.TypeOf(mapped)) + require.NotNil(t, ptr) +} + +func TestCoverage_MarshalSessionOptionsAndInterceptors(t *testing.T) { + m := New(&config.IOConfig{}) + type payload struct { + Items []int + } + + session := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + interceptors := MarshalerInterceptors{ + "Items": func() ([]byte, error) { return []byte(`[9,8,7]`), nil }, + } + data, err := m.Marshal(payload{Items: []int{1, 2, 3}}, session, interceptors) + require.NoError(t, err) + require.Contains(t, string(data), `"Items":[9,8,7]`) + + _, err = m.Marshal(nil) + require.NoError(t, err) +} + +func TestCoverage_PrepareUnmarshalSessionAndInterceptor(t *testing.T) { + m := New(&config.IOConfig{}) + type payload struct { + ID int + } + + um := &UnmarshalSession{} + interceptors := UnmarshalerInterceptors{ + "ID": func(dst interface{}, decoder *gojay.Decoder, options ...interface{}) error { + // consume incoming value but force a custom value + var throwaway int + if err := decoder.Int(&throwaway); err != nil { + return err + } + *dst.(*int) = 77 + return nil + }, + } + + var out payload + require.NoError(t, m.Unmarshal([]byte(`{"ID":1}`), &out, um, interceptors)) + require.Equal(t, 77, out.ID) + require.NotEmpty(t, um.Options) +} + +func TestCoverage_IntAndStringBranches(t *testing.T) { + m := New(&config.IOConfig{}) + + type ints struct { + I8 int8 + I16 int16 + I32 int32 + I64 int64 + } + var out ints + require.NoError(t, m.Unmarshal([]byte(`{"I8":8,"I16":16,"I32":32,"I64":64}`), &out)) + require.EqualValues(t, 8, out.I8) + require.EqualValues(t, 16, out.I16) + require.EqualValues(t, 32, out.I32) + require.EqualValues(t, 64, out.I64) + + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + marshallString("line\u2028sep\u2029par\n\t\r\b\f\"\\/", sb, nil) + require.Contains(t, sb.String(), `\u2028`) + require.Contains(t, 
sb.String(), `\u2029`) +} + +func TestCoverage_MapVariantsAndKeys(t *testing.T) { + m := New(&config.IOConfig{CaseFormat: text.CaseFormatLowerUnderscore}) + + type maps struct { + MI map[string]int + MF map[string]float64 + MS map[string]string + ANY map[string]interface{} + } + var out maps + require.NoError(t, m.Unmarshal([]byte(`{"MI":{"a":1},"MF":{"x":1.5},"MS":{"k":"v"}}`), &out)) + require.Equal(t, 1, out.MI["a"]) + require.Equal(t, 1.5, out.MF["x"]) + require.Equal(t, "v", out.MS["k"]) + + type intKey struct { + M map[int]string + } + enc, err := m.Marshal(intKey{M: map[int]string{1: "x", 2: "y"}}) + require.NoError(t, err) + require.Contains(t, string(enc), `"1":"x"`) + + type anyMap struct { + M map[string]interface{} + } + enc2, err := m.Marshal(anyMap{M: map[string]interface{}{"MyKey": 1}}) + require.NoError(t, err) + require.Contains(t, string(enc2), `"my_key"`) +} + +func TestCoverage_CacheDispatchAndErrors(t *testing.T) { + c := newCache() + cfg := &config.IOConfig{} + + pc := c.pathCache("x") + _, err := pc.getMarshaller(nil, cfg, "x", "x", nil) + require.Error(t, err) + + // Unsupported kind falls into default unsupported branch. + _, err = c.loadMarshaller(reflect.TypeOf(make(chan int)), cfg, "", "", nil) + require.Error(t, err) + + // Load representative kinds to exercise switch branches. 
+ cases := []reflect.Type{ + reflect.TypeOf([2]bool{}), + reflect.TypeOf([]int{}), + reflect.TypeOf([]interface{}{}), + reflect.TypeOf(map[string]int{}), + reflect.TypeOf(time.Time{}), + reflect.TypeOf((*time.Time)(nil)), + reflect.TypeOf(""), + reflect.TypeOf(true), + reflect.TypeOf(float32(0)), + reflect.TypeOf(float64(0)), + reflect.TypeOf(int(0)), + reflect.TypeOf(uint(0)), + reflect.TypeOf((*int)(nil)), + reflect.TypeOf((*uint)(nil)), + } + for _, rType := range cases { + _, err = c.loadMarshaller(rType, cfg, "", "", nil) + require.NoError(t, err) + } +} + +func TestCoverage_OptionPresenceDefaultAndMarshalErrors(t *testing.T) { + empty := Options{123} + require.Nil(t, empty.Tag()) + require.Nil(t, empty.FormatTag()) + + _, err := getFields(reflect.TypeOf(1)) + require.Error(t, err) + + type badTag struct { + F string `default:"broken"` + } + _, err = NewDefaultTag(reflect.TypeOf(badTag{}).Field(0)) + require.Error(t, err) + + type badValue struct { + F int `default:"value=abc"` + } + _, err = NewDefaultTag(reflect.TypeOf(badValue{}).Field(0)) + require.Error(t, err) + + m := New(&config.IOConfig{}) + _, err = m.Marshal(make(chan int)) + require.Error(t, err) +} + +func TestCoverage_LowLevelBranches(t *testing.T) { + // String ensureReplacer nil branch + s := &stringMarshaller{dTag: &format.Tag{}, replacer: nil, defaultValue: `""`} + buf := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + v := "abc" + require.NoError(t, s.MarshallObject(unsafe.Pointer(&v), buf)) + + // Slice interceptor error branch + errExpected := errors.New("interceptor") + _, err := New(&config.IOConfig{}).Marshal( + struct{ X []int }{X: []int{1}}, + MarshalerInterceptors{ + "X": func() ([]byte, error) { return nil, errExpected }, + }, + ) + require.Error(t, err) + + // Time unmarshal invalid input branch + tm := newTimeMarshaller(&format.Tag{}, &config.IOConfig{}) + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`"bad"`))) + defer dec.Release() + var tt time.Time + require.Panics(t, 
func() { + _ = tm.UnmarshallObject(unsafe.Pointer(&tt), dec, nil, &UnmarshalSession{}) + }) +} + +func TestCoverage_ZeroPercentFunctions(t *testing.T) { + // formatFloat + require.Equal(t, "1.25", formatFloat(1.25)) + + // float unmarshallers + pointer variants + type floats struct { + F32 float32 + F64 float64 + P32 *float32 + P64 *float64 + I8 *int8 + I16 *int16 + I32 *int32 + I64 *int64 + U8 *uint8 + U16 *uint16 + U32 *uint32 + U64 *uint64 + } + m := New(&config.IOConfig{}) + var out floats + err := m.Unmarshal([]byte(`{"F32":1.5,"F64":2.5,"P32":3.5,"P64":4.5,"I8":8,"I16":16,"I32":32,"I64":64,"U8":9,"U16":19,"U32":29,"U64":39}`), &out) + require.NoError(t, err) + require.NotNil(t, out.P32) + require.NotNil(t, out.P64) + require.NotNil(t, out.I8) + require.NotNil(t, out.I16) + require.NotNil(t, out.I32) + require.NotNil(t, out.I64) + require.NotNil(t, out.U8) + require.NotNil(t, out.U16) + require.NotNil(t, out.U32) + require.NotNil(t, out.U64) + + // inlinable unmarshal branch (invoke marshaller directly) + type inner struct { + A int + B string + } + type outer struct { + Inner inner `jsonx:"inline"` + } + rType := reflect.TypeOf(outer{}) + field, _ := rType.FieldByName("Inner") + ilm, err := newInlinableMarshaller(field, &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + var o outer + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"A":7,"B":"x"}`))) + defer dec.Release() + require.NoError(t, ilm.UnmarshallObject(unsafe.Pointer(&o.Inner), dec, nil, &UnmarshalSession{})) +} + +func TestCoverage_BranchHelpersAndPrimitiveMarshallers(t *testing.T) { + // isExcluded / filterByPath + ioCfg := &config.IOConfig{Exclude: map[string]bool{"A.B": true}} + require.True(t, isExcluded(nil, "B", ioCfg, "A.B")) + require.False(t, isExcluded(nil, "C", ioCfg, "A.C")) + filters := NewFilters(&FilterEntry{Path: "A", Fields: []string{"X"}}) + f, ok := filterByPath(filters, "A") + require.True(t, ok) + require.True(t, f["X"]) + _, ok = 
filterByPath(nil, "A") + require.False(t, ok) + + // primitive marshallers zero/non-zero branches + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + intV := 0 + require.NoError(t, newIntMarshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&intV), sb)) + intV = 3 + require.NoError(t, newIntMarshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&intV), sb)) + + f32 := float32(0) + require.NoError(t, newFloat32Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&f32), sb)) + f32 = 1.25 + require.NoError(t, newFloat32Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&f32), sb)) + + u := uint(0) + require.NoError(t, newUintMarshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u), sb)) + u = 5 + require.NoError(t, newUintMarshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u), sb)) + + b := false + require.NoError(t, newBoolMarshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&b), sb)) + b = true + require.NoError(t, newBoolMarshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&b), sb)) + + // explicit width marshaller zero/non-zero branches + i8 := int8(0) + require.NoError(t, NewInt8Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i8), sb)) + i8 = 1 + require.NoError(t, NewInt8Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i8), sb)) + + i16 := int16(0) + require.NoError(t, newInt16Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i16), sb)) + i16 = 2 + require.NoError(t, newInt16Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i16), sb)) + + i32 := int32(0) + require.NoError(t, newInt32Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i32), sb)) + i32 = 3 + require.NoError(t, newInt32Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i32), sb)) + + i64 := int64(0) + require.NoError(t, newInt64Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i64), sb)) + i64 = 4 + require.NoError(t, newInt64Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&i64), sb)) + + u8 := uint8(0) + 
require.NoError(t, newUint8Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u8), sb)) + u8 = 1 + require.NoError(t, newUint8Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u8), sb)) + + u16 := uint16(0) + require.NoError(t, newUint16Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u16), sb)) + u16 = 2 + require.NoError(t, newUint16Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u16), sb)) + + u32 := uint32(0) + require.NoError(t, newUint32Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u32), sb)) + u32 = 3 + require.NoError(t, newUint32Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u32), sb)) + + u64 := uint64(0) + require.NoError(t, newUint64Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u64), sb)) + u64 = 4 + require.NoError(t, newUint64Marshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&u64), sb)) +} + +func TestCoverage_InterfaceArrayRawAndWrapperBranches(t *testing.T) { + // interface marshaller hasMethod=true branch + v := withM(withMImpl{}) + im, err := newInterfaceMarshaller(reflect.TypeOf((*withM)(nil)).Elem(), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + require.NotNil(t, im.AsInterface(unsafe.Pointer(&v))) + require.NotNil(t, asInterface(im.xType, unsafe.Pointer(&v))) + + // array unmarshal null path + am, err := newArrayMarshaller(reflect.TypeOf([2]bool{}), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + decNull := gojay.BorrowDecoder(bytes.NewReader([]byte(`null`))) + defer decNull.Release() + var arr [2]bool + require.Error(t, am.UnmarshallObject(unsafe.Pointer(&arr), decNull, nil, &UnmarshalSession{})) + + // raw message marshal nil path + rm := newRawMessageMarshaller() + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + var raw []byte + require.NoError(t, rm.MarshallObject(unsafe.Pointer(&raw), sb)) + raw = []byte(`{"x":1}`) + require.NoError(t, rm.MarshallObject(unsafe.Pointer(&raw), sb)) + + // gojay 
wrapper useMarshal/useUnmarshal false branches + fb := &fallbackMarshaller{} + rType := reflect.TypeOf(struct{ A int }{}) + w := newGojayObjectMarshaller(getXType(rType), getXType(reflect.PtrTo(rType)), fb, false, false) + val := struct{ A int }{A: 1} + require.NoError(t, w.MarshallObject(AsPtr(val, rType), sb)) + require.True(t, fb.marshalCalled) + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"A":1}`))) + defer dec.Release() + ptr := reflect.New(rType) + require.NoError(t, w.UnmarshallObject(unsafe.Pointer(ptr.Pointer()), dec, nil, &UnmarshalSession{})) + require.True(t, fb.unmarshalCalled) + +} + +func TestCoverage_GojayWrapperPointerPathAndSkipNull(t *testing.T) { + // use existing gjOnlyPtr type to hit useMarshal/useUnmarshal=true path + w := newGojayObjectMarshaller(getXType(reflect.TypeOf(gjOnlyPtr{})), getXType(reflect.TypeOf(&gjOnlyPtr{})), &fallbackMarshaller{}, true, true) + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + v := gjOnlyPtr{V: 9} + require.NoError(t, w.MarshallObject(AsPtr(v, reflect.TypeOf(v)), sb)) + + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"V":9}`))) + defer dec.Release() + p := reflect.New(reflect.TypeOf(gjOnlyPtr{})) + require.NoError(t, w.UnmarshallObject(unsafe.Pointer(p.Pointer()), dec, nil, &UnmarshalSession{})) + + // skipNull true/false branches + decNull := gojay.BorrowDecoder(bytes.NewReader([]byte(`null`))) + defer decNull.Release() + _ = skipNull(decNull) + + decNonNull := gojay.BorrowDecoder(bytes.NewReader([]byte(`[]`))) + defer decNonNull.Release() + require.False(t, skipNull(decNonNull)) +} + +func TestCoverage_LowFunctionsExtra(t *testing.T) { + // uint ptr marshaller non-nil branch + up := uint(11) + ptr := &up + upp := &ptr + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + require.NoError(t, newUintPtrMarshaller(&format.Tag{}).MarshallObject(unsafe.Pointer(&upp), sb)) + + // gojay wrapper nil ptr marshal branch + w := newGojayObjectMarshaller( + getXType(reflect.TypeOf(gjOnlyPtr{})), + 
getXType(reflect.TypeOf(&gjOnlyPtr{})), + &fallbackMarshaller{}, + true, + true, + ) + require.NoError(t, w.MarshallObject(nil, sb)) + + // slice decoder error branch + sd := newSliceDecoder(reflect.TypeOf(0), unsafe.Pointer(&[]int{}), xunsafe.NewSlice(reflect.TypeOf([]int{}), xunsafe.UseItemAddrOpt(true)), newIntMarshaller(&format.Tag{}), &UnmarshalSession{}) + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`["x"]`))) + defer dec.Release() + _ = dec.Array(sd) +} + +func TestCoverage_RemainingBranches(t *testing.T) { + // skipNull branches + origData, origCur := decData, decCur + decData, decCur = nil, nil + decDummy := gojay.BorrowDecoder(bytes.NewReader([]byte(`null`))) + require.False(t, skipNull(decDummy)) + decDummy.Release() + decData, decCur = origData, origCur + + decNull := gojay.BorrowDecoder(bytes.NewReader([]byte(`null`))) + _ = skipNull(decNull) + decNull.Release() + + // force internal decoder state to hit skipNull true path + forced := gojay.BorrowDecoder(bytes.NewReader([]byte(`[]`))) + decPtr := unsafe.Pointer(forced) + decData.SetBytes(decPtr, []byte("null")) + decCur.SetInt(decPtr, 0) + require.True(t, skipNull(forced)) + forced.Release() + + // slice marshaller constructor error path + _, err := newSliceMarshaller(reflect.TypeOf([]chan int{}), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.Error(t, err) + + // slice marshaller marshal nil ptr branch + sNoop := &sliceMarshaller{path: "p", xslice: xunsafe.NewSlice(reflect.TypeOf([]int{}), xunsafe.UseItemAddrOpt(true))} + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + require.NoError(t, sNoop.MarshallObject(nil, sb)) + + // slice marshaller interceptor error branch + sInt := &sliceMarshaller{path: "p", xslice: xunsafe.NewSlice(reflect.TypeOf([]int{}), xunsafe.UseItemAddrOpt(true))} + sb2 := &MarshallSession{ + Buffer: bytes.NewBuffer(nil), + Interceptors: MarshalerInterceptors{ + "p": func() ([]byte, error) { return nil, errors.New("x") }, + }, + } + var arr []int + 
require.Error(t, sInt.MarshallObject(unsafe.Pointer(&arr), sb2)) + + // slice decoder error wrapping branch + sd := &sliceDecoder{ + appender: xunsafe.NewSlice(reflect.TypeOf([]int{}), xunsafe.UseItemAddrOpt(true)).Appender(unsafe.Pointer(&arr)), + unmarshaller: &errMarshaller{}, + } + decArr := gojay.BorrowDecoder(bytes.NewReader([]byte(`[1]`))) + err = decArr.Array(sd) + decArr.Release() + require.Error(t, err) + + // slice interface marshaller branches + sim := newSliceInterfaceMarshaller(&config.IOConfig{}, "", "", &format.Tag{}, newCache()).(*sliceInterfaceMarshaller) + ifaces := []interface{}{nil, (*int)(nil), map[string]int{"a": 1}} + sb3 := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + require.NoError(t, sim.MarshallObject(unsafe.Pointer(&ifaces), sb3)) + ifacesBad := []interface{}{make(chan int)} + require.Error(t, sim.MarshallObject(unsafe.Pointer(&ifacesBad), sb3)) + decBad := gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + require.Error(t, sim.UnmarshallObject(unsafe.Pointer(&[]interface{}{}), decBad, nil, &UnmarshalSession{})) + decBad.Release() + + // ptr marshaller constructor and branches + _, err = newPtrMarshaller(reflect.TypeOf((*chan int)(nil)), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.Error(t, err) + pm := &ptrMarshaller{rType: reflect.TypeOf((*int)(nil)), marshaler: newIntMarshaller(&format.Tag{})} + require.NoError(t, pm.MarshallObject(nil, sb)) + var pnil *int + require.NoError(t, pm.MarshallObject(unsafe.Pointer(&pnil), sb)) + decPtr2 := gojay.BorrowDecoder(bytes.NewReader([]byte(`null`))) + require.NoError(t, pm.UnmarshallObject(unsafe.Pointer(&pnil), decPtr2, nil, &UnmarshalSession{})) + decPtr2.Release() + + // map marshaller direct key switch and nil map branches + mm := &mapMarshaller{ + xType: getXType(reflect.TypeOf(map[int]int{})), + keyType: reflect.TypeOf(""), + valueType: reflect.TypeOf(int(0)), + keyMarshaller: newStringMarshaller(&format.Tag{}), + valueMarshaller: newIntMarshaller(&format.Tag{}), 
+ config: &config.IOConfig{}, + } + mval := map[int]int{1: 2} + require.NoError(t, mm.MarshallObject(unsafe.Pointer(&mval), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + mm64 := &mapMarshaller{ + xType: getXType(reflect.TypeOf(map[uint64]int{})), + keyType: reflect.TypeOf(""), + valueType: reflect.TypeOf(int(0)), + keyMarshaller: newStringMarshaller(&format.Tag{}), + valueMarshaller: newIntMarshaller(&format.Tag{}), + config: &config.IOConfig{}, + } + m64 := map[uint64]int{7: 1} + require.NoError(t, mm64.MarshallObject(unsafe.Pointer(&m64), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + mmNil := &mapMarshaller{ + xType: getXType(reflect.TypeOf(map[string]int{})), + keyType: reflect.TypeOf(""), + valueType: reflect.TypeOf(int(0)), + keyMarshaller: newStringMarshaller(&format.Tag{}), + valueMarshaller: newIntMarshaller(&format.Tag{}), + config: &config.IOConfig{}, + } + var nilMap map[string]int + require.NoError(t, mmNil.MarshallObject(unsafe.Pointer(&nilMap), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + // map unmarshaler error branches + mi := &mapStringIntUnmarshaler{aMap: map[string]int{}} + d1 := gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + require.Error(t, mi.UnmarshalJSONObject(d1, "a")) + d1.Release() + mf := &mapStringFloatUnmarshaler{aMap: map[string]float64{}} + d2 := gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + require.Error(t, mf.UnmarshalJSONObject(d2, "a")) + d2.Release() + ms := &mapStringStringUnmarshaler{aMap: map[string]string{}} + d3 := gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + require.Error(t, ms.UnmarshalJSONObject(d3, "a")) + d3.Release() + + // struct helper normalized exclusion branch + ioCfg := &config.IOConfig{Exclude: map[string]bool{"ab": true}} + require.True(t, isExcluded(nil, "X", ioCfg, "A_B")) + + // cache path load miss branch + pc := &pathCache{cache: sync.Map{}} + _, ok := pc.loadMarshaller(reflect.TypeOf(123)) + require.False(t, ok) +} + +func TestCoverage_ConstructorAndNilVariants(t 
*testing.T) { + nullableTag := &format.Tag{} + b := true + nullableTag.Nullable = &b + nonNullableTag := &format.Tag{} + f := false + nonNullableTag.Nullable = &f + + // bool/string/float/int ctor nullable branches + require.Equal(t, null, newBoolMarshaller(nullableTag).zeroValue) + require.Equal(t, null, newStringMarshaller(nullableTag).defaultValue) + require.Equal(t, null, newFloat32Marshaller(nullableTag).zeroValue) + require.Equal(t, null, newFloat64Marshaller(nullableTag).zeroValue) + require.Equal(t, null, newInt64Marshaller(nullableTag).zeroValue) + require.Equal(t, null, intZeroValue(nullableTag)) + require.Equal(t, "0", intZeroValue(nonNullableTag)) + + // time ptr ctor branch + require.Equal(t, "null", newTimePtrMarshaller(nullableTag, &config.IOConfig{}).zeroValue) + require.NotEqual(t, "null", newTimePtrMarshaller(nonNullableTag, &config.IOConfig{}).zeroValue) + + // uint ptr marshaller both nil-pointer branches + um := newUintPtrMarshaller(&format.Tag{}) + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + require.NoError(t, um.MarshallObject(nil, sb)) + var x *uint + require.NoError(t, um.MarshallObject(unsafe.Pointer(&x), sb)) + v := uint(1) + x = &v + require.NoError(t, um.MarshallObject(unsafe.Pointer(&x), sb)) + + // ptr marshaller branch where ptr non-nil but deref nil + pm := &ptrMarshaller{rType: reflect.TypeOf((*int)(nil)), marshaler: newIntMarshaller(&format.Tag{})} + var pi *int + piPtr := &pi + require.NoError(t, pm.MarshallObject(unsafe.Pointer(piPtr), sb)) + + // interface marshaller error branch + im, err := newInterfaceMarshaller(reflect.TypeOf((*interface{})(nil)).Elem(), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + iface := interface{}(make(chan int)) + require.Error(t, im.MarshallObject(unsafe.Pointer(&iface), sb)) + + // raw message marshal ptr nil and unmarshal invalid + rm := newRawMessageMarshaller() + require.NoError(t, rm.MarshallObject(nil, sb)) + decInvalid := 
gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + defer decInvalid.Release() + var raw []byte + require.Error(t, rm.UnmarshallObject(unsafe.Pointer(&raw), decInvalid, nil, &UnmarshalSession{})) + + // array marshaller unsupported branch + am, err := newArrayMarshaller(reflect.TypeOf([1]int{}), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + a := [1]int{1} + require.Error(t, am.MarshallObject(unsafe.Pointer(&a), sb)) + + // force array unmarshal null fast-path + decArrNull := gojay.BorrowDecoder(bytes.NewReader([]byte(`[]`))) + decArrPtr := unsafe.Pointer(decArrNull) + decData.SetBytes(decArrPtr, []byte("null")) + decCur.SetInt(decArrPtr, 0) + require.NoError(t, am.UnmarshallObject(unsafe.Pointer(&a), decArrNull, nil, &UnmarshalSession{})) + decArrNull.Release() + + // custom unmarshaller constructor error path + _, err = newCustomUnmarshaller(reflect.TypeOf(make(chan int)), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.Error(t, err) + + // inlinable marshaller constructor error path + type badInline struct{ C chan int } + field, _ := reflect.TypeOf(badInline{}).FieldByName("C") + _, err = newInlinableMarshaller(field, &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.Error(t, err) + + // decoderError nil field branch + oldErr := decErr + decErr = nil + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{}`))) + require.NoError(t, decoderError(dec)) + dec.Release() + decErr = oldErr +} + +func TestCoverage_AdditionalLowBranches(t *testing.T) { + // string marshaller empty + nullable branch + nullable := &format.Tag{} + tval := true + nullable.Nullable = &tval + sm := newStringMarshaller(nullable) + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + empty := "" + require.NoError(t, sm.MarshallObject(unsafe.Pointer(&empty), sb)) + + // deferred unmarshal fail branch + d := newDeferred() + d.fail(errors.New("uerr")) + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{}`))) + defer 
dec.Release() + require.Error(t, d.UnmarshallObject(nil, dec, nil, &UnmarshalSession{})) + + // ptr unmarshal pointer==nil and auxiliary decoder branch + pm := &ptrMarshaller{rType: reflect.TypeOf((*int)(nil)), marshaler: newIntMarshaller(&format.Tag{})} + require.NoError(t, pm.UnmarshallObject(nil, dec, nil, &UnmarshalSession{})) + var p *int + aux := gojay.BorrowDecoder(bytes.NewReader([]byte(`1`))) + defer aux.Release() + require.NoError(t, pm.UnmarshallObject(unsafe.Pointer(&p), dec, aux, &UnmarshalSession{})) + + // enc BytesSlice error branch + bs := &BytesSlice{b: &[]byte{}} + bad := gojay.BorrowDecoder(bytes.NewReader([]byte(`"x"`))) + defer bad.Release() + _ = bs.UnmarshalJSONArray(bad) + + // struct marshaller branches with anonymous ptr embed and ignores + type Emb struct { + Arr []int + S string + } + type HolderNoOmit struct { + *Emb + Hidden string `json:"-"` + Internal string `internal:"true"` + Name string + } + type HolderOmit struct { + *Emb `json:",omitempty"` + Name string + } + m := New(&config.IOConfig{}) + _, err := m.Marshal(HolderNoOmit{Name: "x"}) // nil embed no omitempty -> explicit null paths + require.NoError(t, err) + _, err = m.Marshal(HolderOmit{Name: "y"}) // nil embed with omitempty -> skip path + require.NoError(t, err) + + // createStructMarshallers self-reference skip branch + type Self struct { + Child []*Self + Name string + } + s, err := newStructMarshaller(&config.IOConfig{}, reflect.TypeOf(Self{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + gf := groupFields(reflect.TypeOf(Self{})) + mrs, err := s.createStructMarshallers(gf, "", "", &format.Tag{}) + require.NoError(t, err) + // only Name should remain (Child is self-reference and skipped) + require.Len(t, mrs, 1) + + // enc.go error path: decoder exhausted + bufBytes := []byte{} + bs2 := &BytesSlice{b: &bufBytes} + dEmpty := gojay.BorrowDecoder(bytes.NewReader([]byte{})) + defer dEmpty.Release() + require.Error(t, bs2.UnmarshalJSONArray(dEmpty)) + + // 
default Init unknown attr ignored, malformed kv errors + type withDefault struct { + V string `default:"unknown=1,value=x"` + } + _, err = NewDefaultTag(reflect.TypeOf(withDefault{}).Field(0)) + require.NoError(t, err) + type badDefault struct { + V string `default:"badformat"` + } + _, err = NewDefaultTag(reflect.TypeOf(badDefault{}).Field(0)) + require.Error(t, err) + + // presence updater error path + type wrongPresence struct { + Has int `setMarker:"true"` + } + _, err = newPresenceUpdater(reflect.TypeOf(wrongPresence{}).Field(0)) + require.Error(t, err) + + // formatName remaining ID branch variants + require.Equal(t, "ID", formatName("ID", text.CaseFormatUpper)) + require.Equal(t, "id", formatName("ID", text.CaseFormatLower)) +} + +func TestCoverage_StructHeavyBranches(t *testing.T) { + // init() error path via unsupported field type + type Bad struct { + C chan int + } + smBad, err := newStructMarshaller(&config.IOConfig{}, reflect.TypeOf(Bad{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + require.Error(t, smBad.init()) + + // init() error path via invalid presence marker type + type BadPresence struct { + ID int + Has int `setMarker:"true"` + } + smBadPresence, err := newStructMarshaller(&config.IOConfig{}, reflect.TypeOf(BadPresence{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + require.Error(t, smBadPresence.init()) + + // UnmarshalObject branch where marker holder is non-pointer struct (lines 85-87) + type HasStruct struct{ ID bool } + type MarkerStruct struct { + ID int + Has HasStruct `setMarker:"true"` + } + smMarker, err := newStructMarshaller(&config.IOConfig{}, reflect.TypeOf(MarkerStruct{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + require.NoError(t, smMarker.init()) + var ms MarkerStruct + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"ID":1}`))) + require.NoError(t, smMarker.UnmarshallObject(unsafe.Pointer(&ms), dec, nil, &UnmarshalSession{})) + dec.Release() + + // 
MarshallObject nil pointer branch (lines 106-109) + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + require.NoError(t, smMarker.MarshallObject(nil, sb)) + + // MarshallObject filter exclusion branch + filter miss branch in isExcluded + mCfg := &config.IOConfig{} + smFilt, err := newStructMarshaller(mCfg, reflect.TypeOf(struct { + A int + B int + }{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + require.NoError(t, smFilt.init()) + v := struct { + A int + B int + }{A: 1, B: 2} + filtered := &MarshallSession{ + Buffer: bytes.NewBuffer(nil), + Filters: NewFilters(&FilterEntry{Path: "", Fields: []string{"A"}}), + } + require.NoError(t, smFilt.MarshallObject(unsafe.Pointer(&v), filtered)) + + // newFieldMarshaller anonymous error path (line 295) via anonymous bad struct + type anonInner struct { + C chan int + } + type anonBad struct { + anonInner + } + smAnonBad, err := newStructMarshaller(&config.IOConfig{}, reflect.TypeOf(anonBad{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + require.Error(t, smAnonBad.init()) + + // newFieldMarshaller ignore path (line 316) is unreachable for valid Go identifiers; keep behavior documented. 
+ + // isZeroValue default false path (line 434) with non-comparable func type + type fnHolder struct { + F func() + } + ff, _ := reflect.TypeOf(fnHolder{}).FieldByName("F") + fmwf := &marshallerWithField{xField: xunsafe.NewField(ff), marshallerMetadata: marshallerMetadata{comparable: false}} + require.False(t, isZeroValue(unsafe.Pointer(&fnHolder{}), fmwf, nil)) + + // structDecoder interceptor error path (lines 494-496) + type simple struct{ A int } + smSimple, err := newStructMarshaller(&config.IOConfig{}, reflect.TypeOf(simple{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + require.NoError(t, smSimple.init()) + sv := simple{} + ud := &structDecoder{ + ptr: unsafe.Pointer(&sv), + marshaller: smSimple, + session: &UnmarshalSession{PathMarshaller: UnmarshalerInterceptors{ + "A": func(dst interface{}, decoder *gojay.Decoder, options ...interface{}) error { + return errors.New("intercept") + }, + }}, + } + dec2 := gojay.BorrowDecoder(bytes.NewReader([]byte(`1`))) + err = ud.unmarshalJson(dec2, "A") + dec2.Release() + require.Error(t, err) +} + +type gjValueOnly struct { + V int +} + +func (g gjValueOnly) MarshalJSONObject(enc *gojay.Encoder) { + enc.IntKey("V", g.V) +} + +func (g gjValueOnly) IsNil() bool { return false } + +func TestCoverage_TargetedReachableBranches(t *testing.T) { + // default.Init empty tag branch + ignorecaseformatter attribute branch + type noDefault struct { + A int + } + aTag := &DefaultTag{} + require.NoError(t, aTag.Init(reflect.TypeOf(noDefault{}).Field(0))) + + type attrDefault struct { + A string `default:"value=x,ignorecaseformatter=true"` + } + aTag2, err := NewDefaultTag(reflect.TypeOf(attrDefault{}).Field(0)) + require.NoError(t, err) + require.True(t, aTag2.IgnoreCaseFormatter) + + // parseValue time branch with empty format fallback + parsed, err := parseValue(reflect.TypeOf(time.Time{}), "2024-01-01T00:00:00Z", "") + require.NoError(t, err) + require.IsType(t, time.Time{}, parsed) + + // namesCaseIndex 
undefined format branch + n := &namesCaseIndex{registry: map[text.CaseFormat]map[string]string{}} + require.Equal(t, "a_b", n.formatTo("a-b", text.CaseFormatLowerUnderscore)) + + // marshal.prepareMarshallSession nil option and filters branch + j := New(&config.IOConfig{}) + sess, putBack := j.prepareMarshallSession([]interface{}{nil, []*FilterEntry{{Path: "", Fields: []string{"A"}}}}) + require.True(t, putBack) + require.NotNil(t, sess.Filters) + + // marshal.Unmarshal error branch when marshaller construction fails + err = j.Unmarshal([]byte(`1`), make(chan int)) + require.Error(t, err) + + // cache getMarshaller ptr fallback error, map error, and custom-unmarshaller struct branch + c := newCache() + _, err = c.loadMarshaller(reflect.TypeOf((*chan int)(nil)), &config.IOConfig{}, "", "", nil) + require.Error(t, err) + _, err = c.loadMarshaller(reflect.TypeOf(map[string]chan int{}), &config.IOConfig{}, "", "", nil) + require.Error(t, err) + + _, err = c.loadMarshaller(reflect.TypeOf(customStruct(0)), &config.IOConfig{}, "", "", nil) + require.NoError(t, err) + + // deferred.resolved nil-target branch + d := newDeferred() + close(d.ready) + _, err = d.resolved() + require.Error(t, err) + + // gojay wrapper value-receiver marshal branch + auxiliary decoder branch + fb := &fallbackMarshaller{} + gw := newGojayObjectMarshaller( + getXType(reflect.TypeOf(gjValueOnly{})), + getXType(reflect.TypeOf(&gjValueOnly{})), + fb, + true, + false, + ) + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + val := gjValueOnly{V: 5} + require.NoError(t, gw.MarshallObject(AsPtr(val, reflect.TypeOf(val)), sb)) + require.Contains(t, sb.String(), `"V":5`) + + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"V":6}`))) + aux := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"V":7}`))) + defer dec.Release() + defer aux.Release() + dst := gjValueOnly{} + require.NoError(t, gw.UnmarshallObject(unsafe.Pointer(&dst), dec, aux, &UnmarshalSession{})) + require.True(t, fb.unmarshalCalled) + + 
// custom marshaller fallback branch + cm := &customMarshaller{ + valueType: getXType(reflect.TypeOf(1)), + addrType: getXType(reflect.TypeOf(new(int))), + marshaller: fb, + } + dec2 := gojay.BorrowDecoder(bytes.NewReader([]byte(`1`))) + defer dec2.Release() + i := 0 + require.NoError(t, cm.UnmarshallObject(unsafe.Pointer(&i), dec2, nil, &UnmarshalSession{})) + require.True(t, fb.unmarshalCalled) + + // array len==0 branch + am, err := newArrayMarshaller(reflect.TypeOf([0]bool{}), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + sbArr := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + arr := [0]bool{} + require.NoError(t, am.MarshallObject(unsafe.Pointer(&arr), sbArr)) + require.Equal(t, "[]", sbArr.String()) + + // ptr unmarshal decoder error branch + pm := &ptrMarshaller{rType: reflect.TypeOf((*int)(nil)), marshaler: newIntMarshaller(&format.Tag{})} + badDec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + defer badDec.Release() + var pi *int + require.Error(t, pm.UnmarshallObject(unsafe.Pointer(&pi), badDec, nil, &UnmarshalSession{})) + + // slice unmarshal decoder.Array error and marshaller error branch + sm := &sliceMarshaller{ + elemType: reflect.TypeOf(0), + marshaller: &errMarshaller{}, + xslice: xunsafe.NewSlice(reflect.TypeOf([]int{}), xunsafe.UseItemAddrOpt(true)), + } + badArrayDec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"x":1}`))) + defer badArrayDec.Release() + sliceDst := []int{} + require.Error(t, sm.UnmarshallObject(unsafe.Pointer(&sliceDst), badArrayDec, nil, &UnmarshalSession{})) + + // sliceInterfaceMarshaller marshaller.MarshallObject error branch + cache := newCache() + cache.pathCache("").storeMarshaler(reflect.TypeOf(1), &errMarshaller{}) + sim := &sliceInterfaceMarshaller{ + cache: cache, + config: &config.IOConfig{}, + tag: &format.Tag{}, + } + list := []interface{}{1} + require.Error(t, sim.MarshallObject(unsafe.Pointer(&list), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + // map 
marshaller int64/default key switch, mapStringIface nil-pointer/counter/error branches + mInt64 := &mapMarshaller{ + xType: getXType(reflect.TypeOf(map[int64]int{})), + keyType: reflect.TypeOf(""), + valueType: reflect.TypeOf(int(0)), + keyMarshaller: newStringMarshaller(&format.Tag{}), + valueMarshaller: newIntMarshaller(&format.Tag{}), + config: &config.IOConfig{}, + } + data64 := map[int64]int{11: 1} + require.NoError(t, mInt64.MarshallObject(unsafe.Pointer(&data64), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + mDefault := &mapMarshaller{ + xType: getXType(reflect.TypeOf(map[float64]int{})), + keyType: reflect.TypeOf(""), + valueType: reflect.TypeOf(int(0)), + keyMarshaller: newStringMarshaller(&format.Tag{}), + valueMarshaller: newIntMarshaller(&format.Tag{}), + config: &config.IOConfig{}, + } + dataDef := map[float64]int{1.5: 2} + require.NoError(t, mDefault.MarshallObject(unsafe.Pointer(&dataDef), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + mIface := &mapMarshaller{ + config: &config.IOConfig{CaseFormat: text.CaseFormatLower}, + valueType: reflect.TypeOf((*interface{})(nil)).Elem(), + valueMarshaller: newInterfaceMarshallerMust(t), + } + fn := mIface.mapStringIfaceMarshaller() + var nilMapPtr *map[string]interface{} + require.NoError(t, fn(unsafe.Pointer(nilMapPtr), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + vmap := map[string]interface{}{"A": 1, "B": 2} + require.NoError(t, fn(unsafe.Pointer(&vmap), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + mIfaceErr := &mapMarshaller{ + config: &config.IOConfig{}, + valueType: reflect.TypeOf((*interface{})(nil)).Elem(), + valueMarshaller: &errMarshaller{}, + } + fnErr := mIfaceErr.mapStringIfaceMarshaller() + require.Error(t, fnErr(unsafe.Pointer(&vmap), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + // time and time ptr constructor/unmarshal error branches + require.Equal(t, "2006", newTimeMarshaller(&format.Tag{TimeLayout: "2006"}, &config.IOConfig{}).timeLayout) + tm := 
newTimeMarshaller(&format.Tag{}, &config.IOConfig{}) + tBad := gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + defer tBad.Release() + var tv time.Time + require.Error(t, tm.UnmarshallObject(unsafe.Pointer(&tv), tBad, nil, &UnmarshalSession{})) + tBad2 := gojay.BorrowDecoder(bytes.NewReader([]byte(`"bad"`))) + defer tBad2.Release() + require.Panics(t, func() { _ = tm.UnmarshallObject(unsafe.Pointer(&tv), tBad2, nil, &UnmarshalSession{}) }) + + require.Equal(t, "2006", newTimePtrMarshaller(&format.Tag{TimeLayout: "2006"}, &config.IOConfig{}).timeLayout) + tpm := newTimePtrMarshaller(&format.Tag{}, &config.IOConfig{}) + tpBad := gojay.BorrowDecoder(bytes.NewReader([]byte(`{`))) + defer tpBad.Release() + var tp *time.Time + require.Error(t, tpm.UnmarshallObject(unsafe.Pointer(&tp), tpBad, nil, &UnmarshalSession{})) + tpBad2 := gojay.BorrowDecoder(bytes.NewReader([]byte(`"bad"`))) + defer tpBad2.Release() + require.Panics(t, func() { _ = tpm.UnmarshallObject(unsafe.Pointer(&tp), tpBad2, nil, &UnmarshalSession{}) }) + + // presence getFields continue non-bool branch + type mixed struct { + I int + B bool + } + fields, err := getFields(reflect.TypeOf(mixed{})) + require.NoError(t, err) + require.Len(t, fields, 1) + + // struct newFieldMarshaller non-letter name branch + marshallerWithField.init parse error branch + sstruct, err := newStructMarshaller(&config.IOConfig{}, reflect.TypeOf(struct{ A int }{}), "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + listMarshallers := make([]*marshallerWithField, 0) + require.NoError(t, sstruct.newFieldMarshaller(&listMarshallers, reflect.StructField{Name: "1bad", Type: reflect.TypeOf(0)}, "", "", &format.Tag{})) + require.Empty(t, listMarshallers) + + mwf := &marshallerWithField{} + require.NoError(t, mwf.init(reflect.StructField{Name: "X", Type: reflect.TypeOf(0), Tag: reflect.StructTag("json:\"x")}, &config.IOConfig{}, newCache())) + + // isZeroValue nil pointer currently panics for slice fields + type hs struct{ S 
[]int } + sf, _ := reflect.TypeOf(hs{}).FieldByName("S") + sfw := &marshallerWithField{xField: xunsafe.NewField(sf)} + require.Panics(t, func() { _ = isZeroValue(nil, sfw, []int{}) }) +} + +func newInterfaceMarshallerMust(t *testing.T) marshaler { + m, err := newInterfaceMarshaller(reflect.TypeOf((*interface{})(nil)).Elem(), &config.IOConfig{}, "", "", &format.Tag{}, newCache()) + require.NoError(t, err) + return m +} + +func TestCoverage_ExtraReachableBranches(t *testing.T) { + cfg := &config.IOConfig{} + cache := newCache() + pc := cache.pathCache("") + + // cache struct branches: base.init error in gojay and non-gojay paths + custom unmarshaller struct branch + _, err := pc.getMarshaller(reflect.TypeOf(gojayBadInit{}), cfg, "", "", &format.Tag{}) + require.Error(t, err) + + type badPlain struct { + C chan int + } + _, err = pc.getMarshaller(reflect.TypeOf(badPlain{}), cfg, "", "", &format.Tag{}) + require.Error(t, err) + + _, err = pc.getMarshaller(reflect.TypeOf(customStructHolder{}), cfg, "", "", &format.Tag{}) + require.NoError(t, err) + + // default tag Name/Embedded attributes + type namedEmbedded struct { + A int `default:"name=abc,embedded=true"` + } + dt, err := NewDefaultTag(reflect.TypeOf(namedEmbedded{}).Field(0)) + require.NoError(t, err) + require.Equal(t, "abc", dt.Name) + require.True(t, dt.Embedded) + + // gojay wrapper: force value-receiver branch by using non-matching addrType + wv := newGojayObjectMarshaller( + getXType(reflect.TypeOf(gjValueOnly{})), + getXType(reflect.TypeOf(0)), + &fallbackMarshaller{}, + true, + true, + ) + sb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + gv := gjValueOnly{V: 12} + require.NoError(t, wv.MarshallObject(AsPtr(gv, reflect.TypeOf(gv)), sb)) + require.Contains(t, sb.String(), `"V":12`) + dec := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"V":12}`))) + aux := gojay.BorrowDecoder(bytes.NewReader([]byte(`{"V":13}`))) + defer dec.Release() + defer aux.Release() + require.NoError(t, 
wv.UnmarshallObject(unsafe.Pointer(&gv), dec, aux, &UnmarshalSession{})) + + // map marshaller key/value marshaller error branches + mKeyErr := &mapMarshaller{ + xType: getXType(reflect.TypeOf(map[int]int{})), + keyType: reflect.TypeOf(""), + valueType: reflect.TypeOf(int(0)), + keyMarshaller: &errMarshaller{}, + valueMarshaller: newIntMarshaller(&format.Tag{}), + config: cfg, + } + mv := map[int]int{1: 2} + require.Error(t, mKeyErr.MarshallObject(unsafe.Pointer(&mv), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + mValErr := &mapMarshaller{ + xType: getXType(reflect.TypeOf(map[int]int{})), + keyType: reflect.TypeOf(""), + valueType: reflect.TypeOf(int(0)), + keyMarshaller: newStringMarshaller(&format.Tag{}), + valueMarshaller: &errMarshaller{}, + config: cfg, + } + require.Error(t, mValErr.MarshallObject(unsafe.Pointer(&mv), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + // map constructor key marshaller error branch via unsupported key kind + _, err = newMapMarshaller(reflect.TypeOf(map[chan int]int{}), cfg, "", "", &format.Tag{}, cache) + require.NoError(t, err) + + // slice unmarshal skipNull true + array decode error; slice marshal nested marshaller error + sm := &sliceMarshaller{ + elemType: reflect.TypeOf(0), + marshaller: &errMarshaller{}, + xslice: xunsafe.NewSlice(reflect.TypeOf([]int{}), xunsafe.UseItemAddrOpt(true)), + } + forced := gojay.BorrowDecoder(bytes.NewReader([]byte(`[]`))) + ptrDec := unsafe.Pointer(forced) + decData.SetBytes(ptrDec, []byte("null")) + decCur.SetInt(ptrDec, 0) + dst := []int{} + require.NoError(t, sm.UnmarshallObject(unsafe.Pointer(&dst), forced, nil, &UnmarshalSession{})) + forced.Release() + + bad := gojay.BorrowDecoder(bytes.NewReader([]byte(`"x"`))) + defer bad.Release() + require.Error(t, sm.UnmarshallObject(unsafe.Pointer(&dst), bad, nil, &UnmarshalSession{})) + + outSlice := []int{1} + require.Error(t, sm.MarshallObject(unsafe.Pointer(&outSlice), &MarshallSession{Buffer: bytes.NewBuffer(nil)})) + + // 
marshallString generic control-char escaping branch (<0x20) + ssb := &MarshallSession{Buffer: bytes.NewBuffer(nil)} + marshallString(string([]byte{0x01}), ssb, nil) + require.Contains(t, ssb.String(), `\\u00`) + + // struct marshaller branches: indirect ignore, nil slice handling, nil pointer field + type embIgnored struct { + N int + } + type holderIgnored struct { + *embIgnored `json:"-"` + A int + } + _, err = New(cfg).Marshal(holderIgnored{A: 1}) + require.NoError(t, err) + + type withNilSlice struct { + S []int + } + _, err = New(cfg).Marshal(withNilSlice{}) + require.NoError(t, err) + + type withPtr struct { + P *int + } + _, err = New(cfg).Marshal(withPtr{}) + require.NoError(t, err) + + // createStructMarshallers inlinable newInlinableMarshaller error branch + type badInlineField struct { + C chan int `jsonx:"inline"` + } + s, err := newStructMarshaller(cfg, reflect.TypeOf(badInlineField{}), "", "", &format.Tag{}, cache) + require.NoError(t, err) + _, err = s.createStructMarshallers(groupFields(reflect.TypeOf(badInlineField{})), "", "", &format.Tag{}) + require.Error(t, err) + + // createStructMarshallers format.Parse error path + parameter/body naming path + type malformedTag struct { + A int `json:"abc` + } + s2, err := newStructMarshaller(cfg, reflect.TypeOf(malformedTag{}), "", "", &format.Tag{}, cache) + require.NoError(t, err) + _, err = s2.createStructMarshallers(groupFields(reflect.TypeOf(malformedTag{})), "", "", &format.Tag{}) + require.NoError(t, err) + + type parameterBody struct { + A int `parameter:"p1,kind=body,in=payload"` + } + s3, err := newStructMarshaller(cfg, reflect.TypeOf(parameterBody{}), "", "", &format.Tag{}, cache) + require.NoError(t, err) + marshallers, err := s3.createStructMarshallers(groupFields(reflect.TypeOf(parameterBody{})), "", "", &format.Tag{}) + require.NoError(t, err) + require.NotEmpty(t, marshallers) +} + +func TestCoverage_LastReachableAttempts(t *testing.T) { + // namesCaseIndex undefined source format path 
(fallback to original value) + n := &namesCaseIndex{registry: map[text.CaseFormat]map[string]string{}} + require.Equal(t, "___", n.formatTo("___", text.CaseFormatLowerCamel)) + + // cache slice constructor error branch (newSliceMarshaller -> elem unsupported) + pc := newCache().pathCache("") + _, err := pc.getMarshaller(reflect.TypeOf([]chan int{}), &config.IOConfig{}, "", "", &format.Tag{}) + require.Error(t, err) +} diff --git a/gateway/router/marshal/json/init.go b/gateway/router/marshal/json/init.go index cacb47602..31eebb40e 100644 --- a/gateway/router/marshal/json/init.go +++ b/gateway/router/marshal/json/init.go @@ -13,6 +13,8 @@ import ( var rawMessageType = reflect.TypeOf(json.RawMessage{}) var unmarshallerIntoType = reflect.TypeOf((*UnmarshalerInto)(nil)).Elem() +var marshalerJSONObjectType = reflect.TypeOf((*gojay.MarshalerJSONObject)(nil)).Elem() +var unmarshalerJSONObjectType = reflect.TypeOf((*gojay.UnmarshalerJSONObject)(nil)).Elem() var mapStringIfaceType = reflect.TypeOf(map[string]interface{}{}) var decData *xunsafe.Field var decCur *xunsafe.Field diff --git a/gateway/router/marshal/json/marshal_test.go b/gateway/router/marshal/json/marshal_test.go index 384f38f11..a094fd282 100644 --- a/gateway/router/marshal/json/marshal_test.go +++ b/gateway/router/marshal/json/marshal_test.go @@ -179,7 +179,7 @@ func TestJson_Marshal(t *testing.T) { }, { description: "escaping special characters", - expect: `{"escaped":"\\__\"__\/__\b__\f__\n__\r__\t__"}`, + expect: `{"escaped":"\\__\"__\/__\\b__\\f__\n__\\r__\t__"}`, data: func() interface{} { type Member struct { escaped string diff --git a/gateway/router/marshal/json/marshaller_bool_ptr.go b/gateway/router/marshal/json/marshaller_bool_ptr.go index 86a562a9d..9a54f47f6 100644 --- a/gateway/router/marshal/json/marshaller_bool_ptr.go +++ b/gateway/router/marshal/json/marshaller_bool_ptr.go @@ -35,5 +35,5 @@ func (i *boolPtrMarshaller) MarshallObject(ptr unsafe.Pointer, sb *MarshallSessi } func (i 
*boolPtrMarshaller) UnmarshallObject(pointer unsafe.Pointer, decoder *gojay.Decoder, auxiliaryDecoder *gojay.Decoder, session *UnmarshalSession) error { - return decoder.AddBool(xunsafe.AsBoolPtr(pointer)) + return decoder.AddBoolNull(xunsafe.AsBoolAddrPtr(pointer)) } diff --git a/gateway/router/marshal/json/marshaller_custom.go b/gateway/router/marshal/json/marshaller_custom.go index eb73890c3..9dcda9c12 100644 --- a/gateway/router/marshal/json/marshaller_custom.go +++ b/gateway/router/marshal/json/marshaller_custom.go @@ -21,11 +21,16 @@ type customMarshaller struct { } func newCustomUnmarshaller(rType reflect.Type, config *config.IOConfig, path string, outputPath string, tag *format.Tag, cache *marshallersCache) (marshaler, error) { - marshaller, err := cache.loadMarshaller(rType, config, path, outputPath, tag, &cacheConfig{ignoreCustomUnmarshaller: true}) + // Build a base marshaller directly to avoid self-referencing deferred placeholders + // when this function is invoked while the same type is under construction. 
+ marshaller, err := cache.pathCache(path).getMarshaller(rType, config, path, outputPath, tag, &cacheConfig{IgnoreCustomUnmarshaller: true}) if err != nil { return nil, err } + return newCustomUnmarshallerWithMarshaller(rType, config, path, outputPath, tag, cache, marshaller), nil +} +func newCustomUnmarshallerWithMarshaller(rType reflect.Type, config *config.IOConfig, path string, outputPath string, tag *format.Tag, cache *marshallersCache, marshaller marshaler) marshaler { return &customMarshaller{ valueType: getXType(rType), addrType: getXType(reflect.PtrTo(rType)), @@ -35,7 +40,7 @@ func newCustomUnmarshaller(rType reflect.Type, config *config.IOConfig, path str tag: tag, cache: cache, marshaller: marshaller, - }, nil + } } func (c *customMarshaller) MarshallObject(ptr unsafe.Pointer, session *MarshallSession) error { return c.marshaller.MarshallObject(ptr, session) diff --git a/gateway/router/marshal/json/marshaller_deferred.go b/gateway/router/marshal/json/marshaller_deferred.go new file mode 100644 index 000000000..48955de80 --- /dev/null +++ b/gateway/router/marshal/json/marshaller_deferred.go @@ -0,0 +1,57 @@ +package json + +import ( + "fmt" + "unsafe" + + "github.com/francoispqt/gojay" +) + +// deferredMarshaller is a placeholder used to break recursive type graphs during construction. +// It forwards calls to the actual target once it is set. 
+type deferredMarshaller struct { + target marshaler + ready chan struct{} + err error +} + +func newDeferred() *deferredMarshaller { + return &deferredMarshaller{ready: make(chan struct{})} +} + +func (d *deferredMarshaller) setTarget(m marshaler) { + d.target = m + close(d.ready) +} + +func (d *deferredMarshaller) fail(e error) { + d.err = e + close(d.ready) // writes to err happen-before any receive on ready +} + +func (d *deferredMarshaller) resolved() (marshaler, error) { + <-d.ready // wait for resolve/fail + if d.err != nil { + return nil, d.err + } + if d.target == nil { + return nil, fmt.Errorf("marshaller not initialized") + } + return d.target, nil +} + +func (d *deferredMarshaller) MarshallObject(ptr unsafe.Pointer, s *MarshallSession) error { + m, err := d.resolved() + if err != nil { + return err + } + return m.MarshallObject(ptr, s) +} + +func (d *deferredMarshaller) UnmarshallObject(p unsafe.Pointer, dec, aux *gojay.Decoder, s *UnmarshalSession) error { + m, err := d.resolved() + if err != nil { + return err + } + return m.UnmarshallObject(p, dec, aux, s) +} diff --git a/gateway/router/marshal/json/marshaller_gojay_object.go b/gateway/router/marshal/json/marshaller_gojay_object.go new file mode 100644 index 000000000..af3cbbec3 --- /dev/null +++ b/gateway/router/marshal/json/marshaller_gojay_object.go @@ -0,0 +1,68 @@ +package json + +import ( + "github.com/francoispqt/gojay" + "github.com/viant/xunsafe" + "unsafe" +) + +// gojayObjectMarshaller delegates to gojay's Marshaler/UnmarshalerJSONObject when available, +// and falls back to the generic struct marshaller for the other direction. 
+type gojayObjectMarshaller struct { + valueType *xunsafe.Type + addrType *xunsafe.Type + fallback marshaler + useMarshal bool + useUnmarshal bool +} + +func newGojayObjectMarshaller(valueType *xunsafe.Type, addrType *xunsafe.Type, fallback marshaler, useMarshal, useUnmarshal bool) *gojayObjectMarshaller { + return &gojayObjectMarshaller{ + valueType: valueType, + addrType: addrType, + fallback: fallback, + useMarshal: useMarshal, + useUnmarshal: useUnmarshal, + } +} + +func (g *gojayObjectMarshaller) MarshallObject(ptr unsafe.Pointer, session *MarshallSession) error { + if ptr == nil { + session.Write(nullBytes) + return nil + } + + if g.useMarshal { + // Prefer pointer receiver if (*T) implements MarshalerJSONObject + if m, ok := g.addrType.Value(ptr).(gojay.MarshalerJSONObject); ok { + enc := gojay.NewEncoder(session.Buffer) + return enc.EncodeObject(m) + } + // Fallback to value receiver if (T) implements MarshalerJSONObject + if m, ok := g.valueType.Interface(ptr).(gojay.MarshalerJSONObject); ok { + enc := gojay.NewEncoder(session.Buffer) + return enc.EncodeObject(m) + } + // If neither matched at runtime, fallback to generic marshaller + } + return g.fallback.MarshallObject(ptr, session) +} + +func (g *gojayObjectMarshaller) UnmarshallObject(pointer unsafe.Pointer, decoder *gojay.Decoder, auxiliaryDecoder *gojay.Decoder, session *UnmarshalSession) error { + if !g.useUnmarshal { + return g.fallback.UnmarshallObject(pointer, decoder, auxiliaryDecoder, session) + } + + d := decoder + if auxiliaryDecoder != nil { + d = auxiliaryDecoder + } + + // Prefer pointer receiver only; value receiver cannot mutate destination reliably. 
+ if u, ok := g.addrType.Value(pointer).(gojay.UnmarshalerJSONObject); ok { + return d.Object(u) + } + + // If neither matched at runtime, fallback to generic unmarshaller + return g.fallback.UnmarshallObject(pointer, decoder, auxiliaryDecoder, session) +} diff --git a/gateway/router/marshal/json/marshaller_interface.go b/gateway/router/marshal/json/marshaller_interface.go index c7da33fe9..4256327c2 100644 --- a/gateway/router/marshal/json/marshaller_interface.go +++ b/gateway/router/marshal/json/marshaller_interface.go @@ -47,7 +47,15 @@ func asInterface(xType *xunsafe.Type, pointer unsafe.Pointer) interface{} { func (i *interfaceMarshaller) MarshallObject(ptr unsafe.Pointer, sb *MarshallSession) error { value := i.AsInterface(ptr) + if value == nil { + sb.Write(nullBytes) + return nil + } rType := reflect.TypeOf(value) + if rType == nil { + sb.Write(nullBytes) + return nil + } marshaller, err := i.cache.loadMarshaller(rType, i.config, i.path, i.outputPath, i.tag) if err != nil { diff --git a/gateway/router/marshal/json/marshaller_map.go b/gateway/router/marshal/json/marshaller_map.go index 320001bbe..427868bf1 100644 --- a/gateway/router/marshal/json/marshaller_map.go +++ b/gateway/router/marshal/json/marshaller_map.go @@ -203,6 +203,9 @@ func (m *mapMarshaller) mapStringIfaceMarshaller() func(pointer unsafe.Pointer, return nil } + // Ensure JSON special characters in keys are escaped + replacer := getReplacer() + if !m.isEmbedded { sb.WriteString("{") } @@ -214,9 +217,9 @@ func (m *mapMarshaller) mapStringIfaceMarshaller() func(pointer unsafe.Pointer, sb.WriteString(",") } counter++ - sb.WriteString(`"`) - sb.WriteString(namesIndex.formatTo(aKey, m.config.CaseFormat)) - sb.WriteString(`":`) + // Write escaped key + marshallString(namesIndex.formatTo(aKey, m.config.CaseFormat), sb, replacer) + sb.WriteString(`:`) if err := m.valueMarshaller.MarshallObject(AsPtr(aValue, m.valueType), sb); err != nil { return err diff --git 
a/gateway/router/marshal/json/marshaller_raw_message.go b/gateway/router/marshal/json/marshaller_raw_message.go index 5ed57e11f..75686d0c2 100644 --- a/gateway/router/marshal/json/marshaller_raw_message.go +++ b/gateway/router/marshal/json/marshaller_raw_message.go @@ -1,6 +1,7 @@ package json import ( + stdjson "encoding/json" "github.com/francoispqt/gojay" "github.com/viant/xunsafe" "unsafe" @@ -14,12 +15,16 @@ func newRawMessageMarshaller() *rawMessageMarshaller { func (r *rawMessageMarshaller) UnmarshallObject(pointer unsafe.Pointer, decoder *gojay.Decoder, auxiliaryDecoder *gojay.Decoder, session *UnmarshalSession) error { bytesPtr := xunsafe.AsBytesPtr(pointer) - dst := "" - if err := decoder.DecodeString(&dst); err != nil { + // Decode arbitrary JSON value into interface{}, then re-marshal to raw bytes. + var val interface{} + if err := decoder.AddInterface(&val); err != nil { return err } - - *bytesPtr = []byte(dst) + data, err := stdjson.Marshal(val) + if err != nil { + return err + } + *bytesPtr = data return nil } diff --git a/gateway/router/marshal/json/marshaller_slice.go b/gateway/router/marshal/json/marshaller_slice.go index 1ad341cef..51d90d3f4 100644 --- a/gateway/router/marshal/json/marshaller_slice.go +++ b/gateway/router/marshal/json/marshaller_slice.go @@ -151,7 +151,15 @@ func (s *sliceInterfaceMarshaller) MarshallObject(ptr unsafe.Pointer, sb *Marsha sb.WriteByte(',') } + if iface == nil { + sb.Write(nullBytes) + continue + } ifaceType := reflect.TypeOf(iface) + if ifaceType == nil { + sb.Write(nullBytes) + continue + } marshaller, err := s.cache.loadMarshaller(ifaceType, s.config, s.path, s.outputPath, s.tag) if err != nil { diff --git a/gateway/router/marshal/json/marshaller_strings.go b/gateway/router/marshal/json/marshaller_strings.go index c044fc5b6..b0ba6fcea 100644 --- a/gateway/router/marshal/json/marshaller_strings.go +++ b/gateway/router/marshal/json/marshaller_strings.go @@ -1,11 +1,12 @@ package json import ( + "strings" + 
"unsafe" + "github.com/francoispqt/gojay" "github.com/viant/tagly/format" "github.com/viant/xunsafe" - "strings" - "unsafe" ) type stringMarshaller struct { @@ -49,16 +50,60 @@ func (i *stringMarshaller) ensureReplacer() { } } -func marshallString(asString string, sb *MarshallSession, replacer *strings.Replacer) { +func marshallString(asString string, sb *MarshallSession, _ *strings.Replacer) { + // Fully JSON-escape the string, including control chars and JS line/paragraph separators. + const hexDigits = "0123456789abcdef" sb.WriteByte('"') - sb.WriteString(replacer.Replace(asString)) + for i := 0; i < len(asString); i++ { + c := asString[i] + switch c { + case '\\', '"': + sb.WriteByte('\\') + sb.WriteByte(c) + case '/': + sb.WriteByte('\\') + sb.WriteByte('/') + case '\b': + sb.WriteString(`\\b`) + case '\f': + sb.WriteString(`\\f`) + case '\n': + sb.WriteString(`\n`) + case '\r': + sb.WriteString(`\\r`) + case '\t': + sb.WriteString(`\t`) + default: + // Escape other control characters < 0x20 as \u00XX + if c < 0x20 { + sb.WriteString(`\\u00`) + sb.WriteByte(hexDigits[c>>4]) + sb.WriteByte(hexDigits[c&0x0F]) + continue + } + // Escape U+2028 and U+2029 to be safe for JS embed contexts + if c == 0xE2 && i+2 < len(asString) { + c1 := asString[i+1] + c2 := asString[i+2] + if c1 == 0x80 && (c2 == 0xA8 || c2 == 0xA9) { + if c2 == 0xA8 { + sb.WriteString(`\\u2028`) + } else { + sb.WriteString(`\\u2029`) + } + i += 2 + continue + } + } + sb.WriteByte(c) + } + } sb.WriteByte('"') } func getReplacer() *strings.Replacer { return strings.NewReplacer(`\`, `\\`, `"`, `\"`, - `/`, `\/`, "\b", `\b`, "\f", `\f`, "\n", `\n`, diff --git a/gateway/router/marshal/json/marshaller_struct.go b/gateway/router/marshal/json/marshaller_struct.go index a0d2d1070..e3d962191 100644 --- a/gateway/router/marshal/json/marshaller_struct.go +++ b/gateway/router/marshal/json/marshaller_struct.go @@ -1,16 +1,18 @@ package json import ( + "reflect" + "strings" + "unicode" + "unsafe" + 
"github.com/francoispqt/gojay" "github.com/viant/datly/gateway/router/marshal/config" + "github.com/viant/datly/view/tags" structology "github.com/viant/structology" "github.com/viant/tagly/format" "github.com/viant/tagly/format/text" xunsafe "github.com/viant/xunsafe" - "reflect" - "strings" - "unicode" - "unsafe" ) type ( @@ -68,7 +70,9 @@ func newStructMarshaller(config *config.IOConfig, rType reflect.Type, path strin marshallersIndex: map[string]int{}, } - return result, result.init() + // Initialization is invoked by cache after it stores the marshaller (or wrapper) + // to break cycles for self-referential types. + return result, nil } func (s *structMarshaller) UnmarshallObject(pointer unsafe.Pointer, decoder *gojay.Decoder, auxiliaryDecoder *gojay.Decoder, session *UnmarshalSession) error { @@ -252,10 +256,19 @@ func (s *structMarshaller) createStructMarshallers(fields *groupedFields, path s if err != nil { return nil, err } + if dTag.Name == "" { //fallback to parameter + if parameterTag := field.Tag.Get("parameter"); parameterTag != "" { + if aTag, _ := tags.Parse(field.Tag, nil, tags.ParameterTag); aTag != nil && aTag.Parameter != nil { + if aTag.Parameter.Kind == "body" { + dTag.Name = aTag.Parameter.In + } + } + } + } elemType := field.Type - switch elemType.Kind() { - case reflect.Ptr, reflect.Slice: + // Unwrap nested pointers/slices to detect self-references like []*T or [][]*T + for elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Slice { elemType = elemType.Elem() } if elemType == fields.owner { @@ -313,6 +326,7 @@ func (s *structMarshaller) newFieldMarshaller(marshallers *[]*marshallerWithFiel } else if s.config.CaseFormat != "" { jsonName = formatName(jsonName, s.config.CaseFormat) } + path, outputPath = addToPath(path, field.Name), addToPath(outputPath, jsonName) xField := xunsafe.NewField(field) diff --git a/gateway/router/marshal/json/marshaller_struct_test.go b/gateway/router/marshal/json/marshaller_struct_test.go new file mode 
100644 index 000000000..845b1e433 --- /dev/null +++ b/gateway/router/marshal/json/marshaller_struct_test.go @@ -0,0 +1,169 @@ +package json + +import ( + stdjson "encoding/json" + "reflect" + "testing" + "time" + + "github.com/viant/datly/gateway/router/marshal/config" + "github.com/viant/tagly/format/text" +) + +// Session represents a user session document. +type Session struct { + // UserID is the PK of the session set. + UserID int `aerospike:"user_id,pk"` + // LastSeen is the last activity timestamp. Stored as unix seconds. + LastSeen *time.Time `aerospike:"last_seen,unixsec"` + // Disabled marks the session as inactive. + Disabled *bool `aerospike:"disabled"` + // Attribute holds session attributes entries. + Attribute []Attribute +} + +// Attribute represents a single attribute entry stored within the session's attributes map bin. +// The PK is still `user_id`, and attribute entries are keyed by `name`. +type Attribute struct { + // UserID is the session owner and record key. + UserID int `aerospike:"user_id,pk"` + // Name is the attribute key (map key). + Name *string `aerospike:"name,mapKey"` + // Value is the attribute payload; supports native Aerospike types. + Value stdjson.RawMessage `aerospike:"value"` +} + +func newMarshaller() *Marshaller { + // We force lowerCamel JSON keys and a time layout that matches the sample payload offset (e.g. "-08"). 
+ cfg := &config.IOConfig{ + CaseFormat: text.CaseFormatLowerCamel, + TimeLayout: "2006-01-02T15:04:05-07", + } + return New(cfg) +} + +func TestUnmarshal_SessionWithAttributes(t *testing.T) { + payload := `[{"attribute":[{"name":"theme","userId":252,"value":{"color":"dark"}}],"disabled":false,"lastSeen":"2025-11-05T17:00:07-08","userId":252}]` + + var got []Session + err := newMarshaller().Unmarshal([]byte(payload), &got) + if err != nil { + t.Fatalf("unexpected unmarshal error: %v", err) + } + if len(got) != 1 { + t.Fatalf("expected 1 session, got %d", len(got)) + } + + s := got[0] + if s.UserID != 252 { + t.Fatalf("expected userId=252, got %d", s.UserID) + } + if s.Disabled == nil || *s.Disabled != false { + t.Fatalf("expected disabled=false, got %v", s.Disabled) + } + if s.LastSeen == nil { + t.Fatalf("expected lastSeen to be set") + } + // Verify attributes + if len(s.Attribute) != 1 { + t.Fatalf("expected 1 attribute, got %d", len(s.Attribute)) + } + a := s.Attribute[0] + if a.UserID != 252 { + t.Fatalf("expected attribute.userId=252, got %d", a.UserID) + } + if a.Name == nil || *a.Name != "theme" { + if a.Name == nil { + t.Fatalf("expected attribute.name=theme, got ") + } + t.Fatalf("expected attribute.name=theme, got %s", *a.Name) + } + // Ensure raw value round-trips as expected JSON + var valueObj map[string]string + if err := stdjson.Unmarshal(a.Value, &valueObj); err != nil { + t.Fatalf("unexpected attribute.value unmarshal error: %v", err) + } + expected := map[string]string{"color": "dark"} + if !reflect.DeepEqual(valueObj, expected) { + t.Fatalf("unexpected attribute.value: got %+v want %+v", valueObj, expected) + } +} + +func TestMarshal_SessionWithAttributes(t *testing.T) { + name := "theme" + disabled := false + ts, err := time.Parse("2006-01-02T15:04:05-07", "2025-11-05T17:00:07-08") + if err != nil { + t.Fatalf("invalid test time: %v", err) + } + raw := stdjson.RawMessage(`{"color":"dark"}`) + data := []Session{ + { + UserID: 252, + LastSeen: 
&ts, + Disabled: &disabled, + Attribute: []Attribute{ + {UserID: 252, Name: &name, Value: raw}, + }, + }, + } + + out, err := newMarshaller().Marshal(data) + if err != nil { + t.Fatalf("unexpected marshal error: %v", err) + } + + // Compare semantically by decoding both expected and actual into generic values. + expected := `[{"attribute":[{"name":"theme","userId":252,"value":{"color":"dark"}}],"disabled":false,"lastSeen":"2025-11-05T17:00:07-08","userId":252}]` + + var gotVal, wantVal interface{} + if err := stdjson.Unmarshal(out, &gotVal); err != nil { + t.Fatalf("unexpected result json: %v, body=%s", err, string(out)) + } + if err := stdjson.Unmarshal([]byte(expected), &wantVal); err != nil { + t.Fatalf("invalid expected json: %v", err) + } + if !reflect.DeepEqual(gotVal, wantVal) { + t.Fatalf("mismatch json:\n got: %s\nwant: %s", string(out), expected) + } +} + +func TestBoolPointer_NullAndPresent(t *testing.T) { + // Case 1: disabled is null -> Disabled == nil + payloadNull := `[{"userId":1,"disabled":null}]` + var s1 []Session + if err := newMarshaller().Unmarshal([]byte(payloadNull), &s1); err != nil { + t.Fatalf("unmarshal null disabled: %v", err) + } + if len(s1) != 1 || s1[0].Disabled != nil { + t.Fatalf("expected Disabled=nil, got %+v", s1) + } + + // Case 2: disabled false -> Disabled != nil and false + payloadFalse := `[{"userId":1,"disabled":false}]` + var s2 []Session + if err := newMarshaller().Unmarshal([]byte(payloadFalse), &s2); err != nil { + t.Fatalf("unmarshal false disabled: %v", err) + } + if len(s2) != 1 || s2[0].Disabled == nil || *s2[0].Disabled != false { + t.Fatalf("expected Disabled=false pointer, got %+v", s2) + } + + // Case 3: marshal with Disabled=nil -> emits null + data := []Session{{UserID: 3}} + out, err := newMarshaller().Marshal(data) + if err != nil { + t.Fatalf("marshal nil disabled: %v", err) + } + // verify null present for disabled if not omitted by config + var v []map[string]interface{} + if err := 
stdjson.Unmarshal(out, &v); err != nil { + t.Fatalf("decode marshalled: %v", err) + } + if _, ok := v[0]["disabled"]; !ok { + t.Fatalf("expected disabled key present; got %s", string(out)) + } + if v[0]["disabled"] != nil { + t.Fatalf("expected disabled=null, got %v", v[0]["disabled"]) + } +} diff --git a/gateway/router/marshal/json/option.go b/gateway/router/marshal/json/option.go index cd1855380..82a8e2dd1 100644 --- a/gateway/router/marshal/json/option.go +++ b/gateway/router/marshal/json/option.go @@ -26,6 +26,6 @@ func (o Options) FormatTag() *format.Tag { } type cacheConfig struct { - ignoreCustomUnmarshaller bool - ignoreCustomMarshaller bool + IgnoreCustomUnmarshaller bool + IgnoreCustomMarshaller bool } diff --git a/gateway/router/marshal/tabjson/reader.go b/gateway/router/marshal/tabjson/reader.go index 3fd52f58a..7784124f9 100644 --- a/gateway/router/marshal/tabjson/reader.go +++ b/gateway/router/marshal/tabjson/reader.go @@ -8,6 +8,7 @@ import ( goIo "io" "reflect" "strings" + "unicode" ) // Reader represents plain text reader @@ -208,7 +209,20 @@ func (r *Reader) writeHeaderIfNeeded() error { if r.stringifierConfig.CaseFormat != format.CaseUpperCamel { for i, field := range fields { caseFormat := text.NewCaseFormat(r.stringifierConfig.CaseFormat.String()) - fields[i] = text.CaseFormatUpperCamel.Format(field, caseFormat) + if field == "Id" && r.stringifierConfig.CaseFormat == format.CaseLowerUnderscore { + fields[i] = "i_d" + continue + } + if strings.ToUpper(field) == field && r.stringifierConfig.CaseFormat == format.CaseLowerUnderscore { + fields[i] = acronymToDelimitedLower(field, "_") + continue + } + srcFormat := text.DetectCaseFormat(field) + if srcFormat.IsDefined() { + fields[i] = srcFormat.Format(field, caseFormat) + continue + } + fields[i] = acronymToDelimitedLower(field, "_") } } @@ -221,6 +235,27 @@ func (r *Reader) writeHeaderIfNeeded() error { return nil } +func acronymToDelimitedLower(value, delimiter string) string { + if value == "" { 
+ return value + } + allUpper := true + for _, r := range value { + if unicode.IsLetter(r) && !unicode.IsUpper(r) { + allUpper = false + break + } + } + if !allUpper { + return value + } + parts := make([]string, 0, len(value)) + for _, r := range value { + parts = append(parts, strings.ToLower(string(r))) + } + return strings.Join(parts, delimiter) +} + func (r *Reader) fields() ([]string, error) { fieldsLen := len(r.stringifierConfig.Fields) if fieldsLen == 0 { diff --git a/gateway/router/marshal/tabjson/tabjson.go b/gateway/router/marshal/tabjson/tabjson.go index 78d74cfd6..c0abd1af2 100644 --- a/gateway/router/marshal/tabjson/tabjson.go +++ b/gateway/router/marshal/tabjson/tabjson.go @@ -113,26 +113,6 @@ func NewMarshaller(rType reflect.Type, config *Config) (*Marshaller, error) { } func ensureSlice(rType reflect.Type) reflect.Type { - destType := rType - if destType.Kind() == reflect.Ptr { - destType = destType.Elem() - } - switch destType.Kind() { - case reflect.Struct: - for i := 0; i < destType.NumField(); i++ { - field := destType.Field(i) - fieldType := field.Type - if fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - if fieldType.Kind() == reflect.Slice { - candidate := fieldType.Elem() - if candidate.Kind() == reflect.Struct || (candidate.Kind() == reflect.Ptr && candidate.Elem().Kind() == reflect.Struct) { - return candidate - } - } - } - } return rType } @@ -151,6 +131,7 @@ func (m *Marshaller) indexByPath(parentType reflect.Type, path string, excluded return } m.uniqueTypes[parentType] = true + defer delete(m.uniqueTypes, parentType) numField := elemParentType.NumField() m.pathAccessors[path] = parentAccessor diff --git a/gateway/router/openapi/generator_operation.go b/gateway/router/openapi/generator_operation.go new file mode 100644 index 000000000..704ffcfb2 --- /dev/null +++ b/gateway/router/openapi/generator_operation.go @@ -0,0 +1,81 @@ +package openapi + +import ( + "context" + openapi 
"github.com/viant/datly/gateway/router/openapi/openapi3" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/shared" + "github.com/viant/datly/view/state" +) + +func (g *generator) generateOperation(ctx context.Context, component *ComponentSchema) (*openapi.Operation, error) { + body, err := g.requestBody(ctx, component) + if err != nil { + return nil, err + } + + parameters, err := g.operationParameters(ctx, component) + if err != nil { + return nil, err + } + + responses, err := g.responses(ctx, component) + if err != nil { + return nil, err + } + + return &openapi.Operation{ + Parameters: dedupe(parameters), + RequestBody: body, + Responses: responses, + }, nil +} + +func (g *generator) operationParameters(ctx context.Context, component *ComponentSchema) ([]*openapi.Parameter, error) { + parameters, err := g.getAllViewsParameters(ctx, component, component.component.View) + if err != nil { + return nil, err + } + + componentParams, err := g.componentOutputParameters(ctx, component) + if err != nil { + return nil, err + } + return append(parameters, componentParams...), nil +} + +func (g *generator) componentOutputParameters(ctx context.Context, component *ComponentSchema) ([]*openapi.Parameter, error) { + result := make([]*openapi.Parameter, 0) + err := g.forEachParam(component.component.Output.Type.Parameters, func(parameter *state.Parameter) (bool, error) { + if parameter.In.Kind != state.KindComponent { + return true, nil + } + + paramComponent, err := g.lookupComponentParam(ctx, component, parameter.In.Name) + if err != nil { + return false, err + } + + viewsParameters, err := g.getAllViewsParameters(ctx, NewComponentSchema(component.components, paramComponent, component.schemas), paramComponent.View) + if err != nil { + return false, err + } + + result = append(result, viewsParameters...) 
+ return true, nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +func (g *generator) lookupComponentParam(ctx context.Context, component *ComponentSchema, path string) (*repository.Component, error) { + method, URI := shared.ExtractPath(path) + provider, err := component.components.Registry().LookupProvider(ctx, &contract.Path{URI: URI, Method: method}) + if err != nil { + return nil, err + } + return provider.Component(ctx) +} diff --git a/gateway/router/openapi/generator_paths.go b/gateway/router/openapi/generator_paths.go new file mode 100644 index 000000000..c36ed5c6a --- /dev/null +++ b/gateway/router/openapi/generator_paths.go @@ -0,0 +1,53 @@ +package openapi + +import ( + "context" + "fmt" + openapi "github.com/viant/datly/gateway/router/openapi/openapi3" + "github.com/viant/datly/repository" + "net/http" +) + +func (g *generator) generatePaths(ctx context.Context, components *repository.Service, providers []*repository.Provider) (*SchemaContainer, openapi.Paths, error) { + container := NewContainer() + builder := &PathsBuilder{paths: openapi.Paths{}} + var retErr error + + for _, provider := range providers { + component, err := provider.Component(ctx) + if err != nil { + retErr = err + } + if component == nil { + fmt.Printf("provider.Component(ctx) returned nil\n") + continue + } + + componentSchema := NewComponentSchema(components, component, container) + operation, err := g.generateOperation(ctx, componentSchema) + if err != nil { + retErr = err + } + + pathItem := &openapi.PathItem{} + attachOperation(pathItem, component.Method, operation) + builder.AddPath(component.URI, pathItem) + } + + return container, builder.paths, retErr +} + +func attachOperation(pathItem *openapi.PathItem, method string, operation *openapi.Operation) { + switch method { + case http.MethodGet: + pathItem.Get = operation + case http.MethodPost: + pathItem.Post = operation + case http.MethodDelete: + pathItem.Delete = operation + case http.MethodPut: + 
pathItem.Put = operation + case http.MethodPatch: + pathItem.Patch = operation + } +} diff --git a/gateway/router/openapi/generator_test.go b/gateway/router/openapi/generator_test.go new file mode 100644 index 000000000..6a3756740 --- /dev/null +++ b/gateway/router/openapi/generator_test.go @@ -0,0 +1,851 @@ +package openapi + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "reflect" + "strings" + "testing" + + openapi3 "github.com/viant/datly/gateway/router/openapi/openapi3" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/version" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestGeneratorTopLevel_Table(t *testing.T) { + ctx := context.Background() + info := openapi3.Info{Title: "api", Version: "1"} + + t.Run("generate spec no providers", func(t *testing.T) { + g := &generator{_schemasIndex: map[string]*openapi3.Schema{}, commonParameters: map[string]*openapi3.Parameter{}, _parametersIndex: map[string]*openapi3.Parameter{}} + spec, err := g.GenerateSpec(ctx, &repository.Service{}, info) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if spec == nil || spec.OpenAPI != "3.0.1" { + t.Fatalf("unexpected spec") + } + }) + + t.Run("wrapper generate no providers", func(t *testing.T) { + spec, err := GenerateOpenAPI3Spec(ctx, &repository.Service{}, info) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if spec == nil || spec.Info == nil || spec.Info.Title != "api" { + t.Fatalf("unexpected wrapper result") + } + }) + + t.Run("generate paths no providers", func(t *testing.T) { + g := &generator{} + schemas, paths, err := g.generatePaths(ctx, &repository.Service{}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if schemas == nil || len(paths) != 0 { + t.Fatalf("unexpected result") + } + }) + + t.Run("marshal generated spec response keys", func(t *testing.T) { + control := &version.Control{} + comp 
:= newTestComponent(t) + comp.Method = http.MethodGet + comp.Path.Method = http.MethodGet + comp.Path.URI = "/v1/spec" + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + provider := repository.NewProvider(comp.Path, control, func(ctx context.Context, opts ...repository.Option) (*repository.Component, error) { return comp, nil }) + + spec, err := GenerateOpenAPI3Spec(ctx, &repository.Service{}, info, provider) + if err != nil { + t.Fatalf("unexpected spec generation error: %v", err) + } + data, err := json.Marshal(spec) + if err != nil { + t.Fatalf("unexpected marshal error: %v", err) + } + doc := string(data) + if !strings.Contains(doc, `"responses":{"`+string(openapi3.ResponseOK)+`":`) { + t.Fatalf("expected serialized numeric response key as string in spec: %s", doc) + } + if !strings.Contains(doc, `"default"`) { + t.Fatalf("expected default response key in spec: %s", doc) + } + }) +} + +func TestAttachOperation_Table(t *testing.T) { + tests := []struct { + name string + method string + assertion func(t *testing.T, item *openapi3.PathItem, op *openapi3.Operation) + }{ + { + name: "get", + method: http.MethodGet, + assertion: func(t *testing.T, item *openapi3.PathItem, op *openapi3.Operation) { + if item.Get != op { + t.Fatalf("expected get operation") + } + }, + }, + { + name: "post", + method: http.MethodPost, + assertion: func(t *testing.T, item *openapi3.PathItem, op *openapi3.Operation) { + if item.Post != op { + t.Fatalf("expected post operation") + } + }, + }, + { + name: "delete", + method: http.MethodDelete, + assertion: func(t *testing.T, item *openapi3.PathItem, op *openapi3.Operation) { + if item.Delete != op { + t.Fatalf("expected delete operation") + } + }, + }, + { + name: "put", + method: http.MethodPut, + assertion: func(t *testing.T, item *openapi3.PathItem, op *openapi3.Operation) { + if item.Put != op { + t.Fatalf("expected 
put operation") + } + }, + }, + { + name: "patch", + method: http.MethodPatch, + assertion: func(t *testing.T, item *openapi3.PathItem, op *openapi3.Operation) { + if item.Patch != op { + t.Fatalf("expected patch operation") + } + }, + }, + { + name: "unsupported", + method: "TRACE", + assertion: func(t *testing.T, item *openapi3.PathItem, op *openapi3.Operation) { + if item.Get != nil || item.Post != nil || item.Delete != nil || item.Put != nil || item.Patch != nil { + t.Fatalf("did not expect any method to be set") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + item := &openapi3.PathItem{} + op := &openapi3.Operation{} + attachOperation(item, tt.method, op) + tt.assertion(t, item, op) + }) + } +} + +func TestGeneratorHelpersMore_Table(t *testing.T) { + g := &generator{} + + t.Run("view parameters empty", func(t *testing.T) { + comp := &ComponentSchema{component: &repository.Component{}, schemas: NewContainer()} + v := &view.View{Template: &view.Template{}, Selector: &view.Config{}} + params, err := g.viewParameters(context.Background(), v, comp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(params) != 0 { + t.Fatalf("expected no params") + } + }) + + t.Run("get all views params empty with relation", func(t *testing.T) { + comp := &ComponentSchema{component: &repository.Component{}, schemas: NewContainer()} + v := &view.View{Template: &view.Template{}, Selector: &view.Config{}, With: []*view.Relation{{Of: &view.ReferenceView{View: view.View{Template: &view.Template{}, Selector: &view.Config{}}}}}} + params, err := g.getAllViewsParameters(context.Background(), comp, v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(params) != 0 { + t.Fatalf("expected no params") + } + }) + + t.Run("append built-in nil", func(t *testing.T) { + comp := &ComponentSchema{component: &repository.Component{}, schemas: NewContainer()} + params := []*openapi3.Parameter{} + if err := 
g.appendBuiltInParam(context.Background(), ¶ms, comp, nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("request body nil for get", func(t *testing.T) { + comp := &ComponentSchema{component: &repository.Component{Path: repository.Component{}.Path}, schemas: NewContainer()} + comp.component.Path.Method = http.MethodGet + body, err := g.requestBody(context.Background(), comp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if body != nil { + t.Fatalf("expected nil body") + } + }) + + t.Run("responses nil for options", func(t *testing.T) { + comp := &ComponentSchema{component: &repository.Component{}, schemas: NewContainer()} + comp.component.Method = http.MethodOptions + resp, err := g.responses(context.Background(), comp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil { + t.Fatalf("expected non-nil response map") + } + if len(resp) != 0 { + t.Fatalf("expected empty response map for options") + } + }) + + t.Run("request body for post", func(t *testing.T) { + comp := newTestComponent(t) + comp.Path.Method = http.MethodPost + comp.Input.Body = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp.Input.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + body, err := g.requestBody(context.Background(), cSchema) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if body == nil || body.Content[ApplicationJson] == nil { + t.Fatalf("expected request body") + } + }) + + t.Run("responses success and default", func(t *testing.T) { + comp := newTestComponent(t) + comp.Method = http.MethodGet + comp.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + resp, err := g.responses(context.Background(), cSchema) + if err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + if _, ok := openapi3.GetResponse(resp, openapi3.ResponseOK); !ok { + t.Fatalf("expected success response") + } + if _, ok := openapi3.GetResponse(resp, openapi3.ResponseDefault); !ok { + t.Fatalf("expected standard responses") + } + }) + + t.Run("convert param query", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + param := &state.Parameter{Name: "ID", In: &state.Location{Kind: state.KindQuery}, Schema: state.NewSchema(reflect.TypeOf(1))} + converted, ok, err := g.convertParam(context.Background(), cSchema, param, "") + if err != nil || !ok || len(converted) != 1 { + t.Fatalf("unexpected convert result: %v %v %d", ok, err, len(converted)) + } + }) + + t.Run("convert param kind whitelist", func(t *testing.T) { + testCases := []struct { + name string + kind state.Kind + expectKeep bool + }{ + {name: "header", kind: state.KindHeader, expectKeep: true}, + {name: "query", kind: state.KindQuery, expectKeep: true}, + {name: "form", kind: state.KindForm, expectKeep: true}, + {name: "body skipped in parameter list", kind: state.KindRequestBody, expectKeep: false}, + {name: "path skipped", kind: state.KindPath, expectKeep: false}, + {name: "cookie skipped", kind: state.KindCookie, expectKeep: false}, + {name: "state skipped", kind: state.KindState, expectKeep: false}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + param := 
&state.Parameter{ + Name: "ID", + In: &state.Location{Kind: tc.kind, Name: "id"}, + Schema: state.NewSchema(reflect.TypeOf(1)), + } + + converted, ok, err := g.convertParam(context.Background(), cSchema, param, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok != tc.expectKeep { + t.Fatalf("expected keep=%v, got %v", tc.expectKeep, ok) + } + if tc.expectKeep && len(converted) != 1 { + t.Fatalf("expected one converted parameter, got %d", len(converted)) + } + if !tc.expectKeep && len(converted) != 0 { + t.Fatalf("expected no converted parameters, got %d", len(converted)) + } + }) + } + }) + + t.Run("convert param object and non-http", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + + objectParam := &state.Parameter{ + Name: "Obj", + In: &state.Location{Kind: state.KindObject}, + Object: state.Parameters{ + {Name: "A", In: state.NewQueryLocation("a"), Schema: state.NewSchema(reflect.TypeOf(""))}, + }, + } + converted, ok, err := g.convertParam(context.Background(), cSchema, objectParam, "") + if err != nil || !ok || len(converted) != 1 { + t.Fatalf("unexpected object convert: %v %v %d", ok, err, len(converted)) + } + + nonHTTP := &state.Parameter{Name: "S", In: &state.Location{Kind: state.KindState, Name: "state"}, Schema: state.NewSchema(reflect.TypeOf(""))} + converted, ok, err = g.convertParam(context.Background(), cSchema, nonHTTP, "") + if err != nil || ok || len(converted) != 0 { + t.Fatalf("unexpected non-http convert: %v %v %d", ok, err, len(converted)) + } + }) + + t.Run("convert param via kind param and cache ref", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: 
map[string]*openapi3.Parameter{},
+		}
+		comp := newTestComponent(t)
+		comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}}
+		base := &state.Parameter{Name: "ID", In: state.NewQueryLocation("id"), Schema: state.NewSchema(reflect.TypeOf(1))}
+		comp.Input.Type.Parameters = state.Parameters{base}
+		cSchema := &ComponentSchema{component: comp, schemas: NewContainer()}
+
+		refParam := &state.Parameter{Name: "Ref", In: &state.Location{Kind: state.KindParam, Name: "ID"}, Schema: state.NewSchema(reflect.TypeOf(1))}
+		converted, ok, err := g.convertParam(context.Background(), cSchema, refParam, "")
+		if err != nil || !ok || len(converted) != 1 {
+			t.Fatalf("unexpected kind-param convert: %v %v %d", ok, err, len(converted))
+		}
+
+		converted, ok, err = g.convertParam(context.Background(), cSchema, base, "")
+		if err != nil || !ok || len(converted) != 1 {
+			t.Fatalf("unexpected cache convert: %v %v %#v", ok, err, converted)
+		}
+	})
+
+	t.Run("append built-in and view params", func(t *testing.T) {
+		g := &generator{
+			_parametersIndex: map[string]*openapi3.Parameter{},
+			commonParameters: map[string]*openapi3.Parameter{},
+		}
+		comp := newTestComponent(t)
+		comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}}
+		cSchema := &ComponentSchema{component: comp, schemas: NewContainer()}
+		params := []*openapi3.Parameter{}
+		param := &state.Parameter{Name: "Limit", In: state.NewQueryLocation("limit"), Schema: state.NewSchema(reflect.TypeOf(1))}
+		if err := g.appendBuiltInParam(context.Background(), &params, cSchema, param); err != nil {
+			t.Fatalf("unexpected append error: %v", err)
+		}
+		if len(params) == 0 {
+			t.Fatalf("expected built-in param")
+		}
+	})
+
+	t.Run("view parameters with selector built-ins", func(t *testing.T) {
+		g := &generator{
+			_parametersIndex: map[string]*openapi3.Parameter{},
+			commonParameters: map[string]*openapi3.Parameter{},
+		}
+		comp := newTestComponent(t)
+		comp.View = &view.View{Template: &view.Template{},
Selector: &view.Config{}} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + v := &view.View{ + Template: &view.Template{ + Parameters: state.Parameters{ + {Name: "Q", In: state.NewQueryLocation("q"), Schema: state.NewSchema(reflect.TypeOf(""))}, + }, + }, + Selector: &view.Config{ + CriteriaParameter: &state.Parameter{Name: "Criteria", In: state.NewQueryLocation("_criteria"), Schema: state.NewSchema(reflect.TypeOf(""))}, + LimitParameter: &state.Parameter{Name: "Limit", In: state.NewQueryLocation("_limit"), Schema: state.NewSchema(reflect.TypeOf(1))}, + OffsetParameter: &state.Parameter{Name: "Offset", In: state.NewQueryLocation("_offset"), Schema: state.NewSchema(reflect.TypeOf(1))}, + PageParameter: &state.Parameter{Name: "Page", In: state.NewQueryLocation("_page"), Schema: state.NewSchema(reflect.TypeOf(1))}, + OrderByParameter: &state.Parameter{Name: "OrderBy", In: state.NewQueryLocation("_orderby"), Schema: state.NewSchema(reflect.TypeOf([]string{}))}, + FieldsParameter: &state.Parameter{Name: "Fields", In: state.NewQueryLocation("_fields"), Schema: state.NewSchema(reflect.TypeOf([]string{}))}, + }, + } + params, err := g.viewParameters(context.Background(), v, cSchema) + if err != nil { + t.Fatalf("unexpected viewParameters error: %v", err) + } + if len(params) < 7 { + t.Fatalf("expected builtin and template params, got %d", len(params)) + } + + v.Template.Parameters = append(v.Template.Parameters, &state.Parameter{Name: "StateParam", In: &state.Location{Kind: state.KindState, Name: "s"}, Schema: state.NewSchema(reflect.TypeOf(""))}) + params, err = g.viewParameters(context.Background(), v, cSchema) + if err != nil { + t.Fatalf("unexpected viewParameters error: %v", err) + } + if len(params) < 7 { + t.Fatalf("expected params with non-http skipped, got %d", len(params)) + } + }) + + t.Run("generate operation happy path", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: 
map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.Method = http.MethodPost + comp.Path.Method = http.MethodPost + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp.Input.Body = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp.Input.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + operation, err := g.generateOperation(context.Background(), cSchema) + if err != nil || operation == nil { + t.Fatalf("unexpected operation result: %v %v", operation, err) + } + if _, ok := openapi3.GetResponse(operation.Responses, openapi3.ResponseOK); !ok { + t.Fatalf("expected 200 response") + } + }) + + t.Run("generate operation with component parameter", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + + components := &repository.Service{} + registry := repository.NewRegistry("", nil, nil) + setUnexportedField(components, "registry", registry) + + dep := newTestComponent(t) + dep.Method = http.MethodGet + dep.Path.Method = http.MethodGet + dep.Path.URI = "/v1/dep" + dep.View = &view.View{ + Template: &view.Template{ + Parameters: state.Parameters{ + {Name: "DepID", In: state.NewQueryLocation("depId"), Schema: state.NewSchema(reflect.TypeOf(1))}, + }, + }, + Selector: &view.Config{}, + } + dep.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + registry.Register(dep) + + comp := newTestComponent(t) + comp.Method = http.MethodPost + comp.Path.Method = http.MethodPost + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp.Input.Body = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp.Input.Type 
= state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp.Output.Type = state.Type{ + Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{})), + Parameters: state.Parameters{ + {Name: "Dep", In: &state.Location{Kind: state.KindComponent, Name: "GET:/v1/dep"}}, + }, + } + + cSchema := &ComponentSchema{component: comp, components: components, schemas: NewContainer()} + operation, err := g.generateOperation(context.Background(), cSchema) + if err != nil { + t.Fatalf("unexpected operation error: %v", err) + } + if operation == nil || len(operation.Parameters) == 0 { + t.Fatalf("expected operation with merged parameters") + } + }) + + t.Run("generate operation request body error", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.Method = http.MethodPost + comp.Path.Method = http.MethodPost + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp.Input.Body = state.Type{Schema: state.NewSchema(reflect.TypeOf((chan int)(nil)))} + comp.Input.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf((chan int)(nil)))} + comp.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + if _, err := g.generateOperation(context.Background(), cSchema); err == nil { + t.Fatalf("expected request body generation error") + } + }) + + t.Run("generate operation response error", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.Method = http.MethodGet + comp.Path.Method = http.MethodGet + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf((chan int)(nil)))} + 
cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + if _, err := g.generateOperation(context.Background(), cSchema); err == nil { + t.Fatalf("expected response generation error") + } + }) + + t.Run("generate paths with providers", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + control := &version.Control{} + comp1 := newTestComponent(t) + comp1.Method = http.MethodGet + comp1.Path.Method = http.MethodGet + comp1.Path.URI = "/v1/get" + comp1.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp1.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + + comp2 := newTestComponent(t) + comp2.Method = http.MethodPost + comp2.Path.Method = http.MethodPost + comp2.Path.URI = "/v1/post" + comp2.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp2.Input.Body = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp2.Input.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp2.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + + provider1 := repository.NewProvider(comp1.Path, control, func(ctx context.Context, opts ...repository.Option) (*repository.Component, error) { return comp1, nil }) + provider2 := repository.NewProvider(comp2.Path, control, func(ctx context.Context, opts ...repository.Option) (*repository.Component, error) { return comp2, nil }) + + _, paths, err := g.generatePaths(context.Background(), &repository.Service{}, []*repository.Provider{provider1, provider2}) + if err != nil { + t.Fatalf("unexpected generate paths error: %v", err) + } + if paths["/v1/get"] == nil || paths["/v1/post"] == nil { + t.Fatalf("expected generated paths") + } + if paths["/v1/get"].Get == nil || paths["/v1/get"].Post != nil { + t.Fatalf("expected isolated GET path item") + } 
+ if paths["/v1/post"].Post == nil || paths["/v1/post"].Get != nil { + t.Fatalf("expected isolated POST path item") + } + }) + + t.Run("generate paths with all methods and provider errors", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + control := &version.Control{} + + mk := func(method, uri string) *repository.Provider { + comp := newTestComponent(t) + comp.Method = method + comp.Path.Method = method + comp.Path.URI = uri + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + if method != http.MethodGet { + comp.Input.Body = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + comp.Input.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + } + return repository.NewProvider(comp.Path, control, func(ctx context.Context, opts ...repository.Option) (*repository.Component, error) { return comp, nil }) + } + + errProvider := repository.NewProvider(contract.Path{Method: http.MethodGet, URI: "/v1/error"}, control, func(ctx context.Context, opts ...repository.Option) (*repository.Component, error) { + return nil, errors.New("provider error") + }) + + controlDeleted := &version.Control{} + controlDeleted.SetChangeKind(version.ChangeKindDeleted) + nilProvider := repository.NewProvider(contract.Path{Method: http.MethodGet, URI: "/v1/nil"}, controlDeleted, func(ctx context.Context, opts ...repository.Option) (*repository.Component, error) { + return nil, nil + }) + + providers := []*repository.Provider{ + mk(http.MethodDelete, "/v1/delete"), + mk(http.MethodPut, "/v1/put"), + mk(http.MethodPatch, "/v1/patch"), + errProvider, + nilProvider, + } + _, paths, err := g.generatePaths(context.Background(), &repository.Service{}, providers) + if err == nil { + t.Fatalf("expected provider error") + } + if 
paths["/v1/delete"] == nil || paths["/v1/put"] == nil || paths["/v1/patch"] == nil { + t.Fatalf("expected generated method paths") + } + }) + + t.Run("operation parameters include component output params", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + + components := &repository.Service{} + registry := repository.NewRegistry("", nil, nil) + setUnexportedField(components, "registry", registry) + + dep := newTestComponent(t) + dep.Method = http.MethodGet + dep.Path.Method = http.MethodGet + dep.Path.URI = "/v1/opdep" + dep.View = &view.View{ + Template: &view.Template{ + Parameters: state.Parameters{ + {Name: "DepID", In: state.NewQueryLocation("depId"), Schema: state.NewSchema(reflect.TypeOf(1))}, + }, + }, + Selector: &view.Config{}, + } + dep.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + registry.Register(dep) + + comp := newTestComponent(t) + comp.View = &view.View{ + Template: &view.Template{ + Parameters: state.Parameters{ + {Name: "RootQ", In: state.NewQueryLocation("q"), Schema: state.NewSchema(reflect.TypeOf(""))}, + }, + }, + Selector: &view.Config{}, + } + comp.Output.Type = state.Type{ + Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{})), + Parameters: state.Parameters{ + {Name: "Dep", In: &state.Location{Kind: state.KindComponent, Name: "GET:/v1/opdep"}}, + }, + } + + cSchema := &ComponentSchema{component: comp, components: components, schemas: NewContainer()} + params, err := g.operationParameters(context.Background(), cSchema) + if err != nil { + t.Fatalf("unexpected operationParameters error: %v", err) + } + if len(params) < 2 { + t.Fatalf("expected root and component params, got %d", len(params)) + } + }) + + t.Run("component output parameters no component refs", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: 
map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.Output.Type = state.Type{ + Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{})), + Parameters: state.Parameters{{Name: "OnlyState", In: &state.Location{Kind: state.KindState, Name: "s"}}}, + } + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + params, err := g.componentOutputParameters(context.Background(), cSchema) + if err != nil { + t.Fatalf("unexpected componentOutputParameters error: %v", err) + } + if len(params) != 0 { + t.Fatalf("expected no component params, got %d", len(params)) + } + }) + + t.Run("lookup component param error", func(t *testing.T) { + g := &generator{} + components := &repository.Service{} + registry := repository.NewRegistry("", nil, nil) + setUnexportedField(components, "registry", registry) + dep := newTestComponent(t) + dep.Method = http.MethodGet + dep.Path.Method = http.MethodGet + dep.Path.URI = "/v1/existing" + registry.Register(dep) + comp := newTestComponent(t) + cSchema := &ComponentSchema{component: comp, components: components, schemas: NewContainer()} + if _, err := g.lookupComponentParam(context.Background(), cSchema, "GET:/v1/missing"); err == nil { + t.Fatalf("expected missing provider error") + } + }) + + t.Run("operation parameters missing component provider", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + components := &repository.Service{} + registry := repository.NewRegistry("", nil, nil) + setUnexportedField(components, "registry", registry) + existing := newTestComponent(t) + existing.Method = http.MethodGet + existing.Path.Method = http.MethodGet + existing.Path.URI = "/v1/existing" + registry.Register(existing) + + comp := newTestComponent(t) + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + comp.Output.Type = state.Type{ + Schema: state.NewSchema(reflect.TypeOf(struct{ ID 
int }{})), + Parameters: state.Parameters{ + {Name: "MissingDep", In: &state.Location{Kind: state.KindComponent, Name: "GET:/v1/unknown"}}, + }, + } + cSchema := &ComponentSchema{component: comp, components: components, schemas: NewContainer()} + if _, err := g.operationParameters(context.Background(), cSchema); err == nil { + t.Fatalf("expected missing dependency error") + } + }) + + t.Run("convert param cache nil ref", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + param := &state.Parameter{Name: "ID", In: &state.Location{Kind: state.KindQuery}, Schema: state.NewSchema(reflect.TypeOf(1))} + + first, ok, err := g.convertParam(context.Background(), cSchema, param, "") + if err != nil || !ok || len(first) != 1 { + t.Fatalf("unexpected first convert result: %v %v %d", ok, err, len(first)) + } + + second, ok, err := g.convertParam(context.Background(), cSchema, param, "") + if err != nil || !ok || len(second) != 1 { + t.Fatalf("unexpected second convert result: %v %v %d", ok, err, len(second)) + } + if second[0].Ref == "" { + t.Fatalf("expected parameter ref") + } + + third, ok, err := g.convertParam(context.Background(), cSchema, param, "") + if err != nil || !ok || len(third) != 1 { + t.Fatalf("unexpected third convert result: %v %v %d", ok, err, len(third)) + } + if third[0].Ref == "" { + t.Fatalf("expected cached nil path to still return ref") + } + }) + + t.Run("convert param kind param skips component reference", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + base := 
&state.Parameter{Name: "Auth", In: &state.Location{Kind: state.KindComponent, Name: "GET:/v1/auth"}, Schema: state.NewSchema(reflect.TypeOf(struct{ A int }{}))} + comp.Input.Type.Parameters = state.Parameters{base} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + + refParam := &state.Parameter{Name: "AuthRef", In: &state.Location{Kind: state.KindParam, Name: "Auth"}, Schema: state.NewSchema(reflect.TypeOf(struct{ A int }{}))} + converted, ok, err := g.convertParam(context.Background(), cSchema, refParam, "") + if err != nil || ok || len(converted) != 0 { + t.Fatalf("expected component kind param reference to be skipped, got ok=%v err=%v len=%d", ok, err, len(converted)) + } + }) + + t.Run("convert param kind param body reference skipped from parameters", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: map[string]*openapi3.Parameter{}, + } + comp := newTestComponent(t) + comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}} + base := &state.Parameter{Name: "BodyInput", In: &state.Location{Kind: state.KindRequestBody}, Schema: state.NewSchema(reflect.TypeOf(struct{ A int }{}))} + comp.Input.Type.Parameters = state.Parameters{base} + cSchema := &ComponentSchema{component: comp, schemas: NewContainer()} + + refParam := &state.Parameter{Name: "BodyRef", In: &state.Location{Kind: state.KindParam, Name: "BodyInput"}, Schema: state.NewSchema(reflect.TypeOf(struct{ A int }{}))} + converted, ok, err := g.convertParam(context.Background(), cSchema, refParam, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok || len(converted) != 0 { + t.Fatalf("expected body-derived kind=param to be skipped from parameter list, got ok=%v len=%d", ok, len(converted)) + } + }) + + t.Run("append built-in non-http parameter", func(t *testing.T) { + g := &generator{ + _parametersIndex: map[string]*openapi3.Parameter{}, + commonParameters: 
map[string]*openapi3.Parameter{},
+		}
+		comp := newTestComponent(t)
+		comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}}
+		cSchema := &ComponentSchema{component: comp, schemas: NewContainer()}
+		params := []*openapi3.Parameter{}
+		stateParam := &state.Parameter{Name: "StateOnly", In: &state.Location{Kind: state.KindState, Name: "state"}, Schema: state.NewSchema(reflect.TypeOf(""))}
+		if err := g.appendBuiltInParam(context.Background(), &params, cSchema, stateParam); err != nil {
+			t.Fatalf("unexpected append error: %v", err)
+		}
+		if len(params) != 0 {
+			t.Fatalf("expected non-http built-in param to be skipped")
+		}
+	})
+
+	t.Run("view parameters and relation errors", func(t *testing.T) {
+		g := &generator{
+			_parametersIndex: map[string]*openapi3.Parameter{},
+			commonParameters: map[string]*openapi3.Parameter{},
+		}
+		comp := newTestComponent(t)
+		comp.View = &view.View{Template: &view.Template{}, Selector: &view.Config{}}
+		cSchema := &ComponentSchema{component: comp, schemas: NewContainer()}
+
+		errorView := &view.View{
+			Template: &view.Template{
+				Parameters: state.Parameters{
+					{Name: "Bad", In: state.NewQueryLocation("bad"), Schema: state.NewSchema(reflect.TypeOf((chan int)(nil)))},
+				},
+			},
+			Selector: &view.Config{},
+		}
+		if _, err := g.viewParameters(context.Background(), errorView, cSchema); err == nil {
+			t.Fatalf("expected view parameter conversion error")
+		}
+
+		relationErrorView := &view.View{
+			Template: &view.Template{},
+			Selector: &view.Config{},
+			With: []*view.Relation{
+				{Of: &view.ReferenceView{View: *errorView}},
+			},
+		}
+		if _, err := g.getAllViewsParameters(context.Background(), cSchema, relationErrorView); err == nil {
+			t.Fatalf("expected relation parameter conversion error")
+		}
+	})
+}
diff --git a/gateway/router/openapi/helpers_test.go b/gateway/router/openapi/helpers_test.go
new file mode 100644
index 000000000..44d4c56ef
--- /dev/null
+++ b/gateway/router/openapi/helpers_test.go
@@ -0,0 +1,42 @@
+package 
openapi + +import ( + "testing" + + openapi3 "github.com/viant/datly/gateway/router/openapi/openapi3" +) + +func TestDedupe(t *testing.T) { + tests := []struct { + name string + in []*openapi3.Parameter + expectNames []string + }{ + { + name: "dedupes by name and location", + in: []*openapi3.Parameter{ + {Name: "id", In: "query"}, + {Name: "id", In: "query"}, + {Name: "id", In: "path"}, + {Name: "limit", In: "query"}, + }, + expectNames: []string{"id:query", "id:path", "limit:query"}, + }, + {name: "empty", in: nil, expectNames: nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := dedupe(tt.in) + if len(out) != len(tt.expectNames) { + t.Fatalf("expected len %d, got %d", len(tt.expectNames), len(out)) + } + for i := range out { + actual := out[i].Name + ":" + out[i].In + if actual != tt.expectNames[i] { + t.Fatalf("at %d expected %q, got %q", i, tt.expectNames[i], actual) + } + } + }) + } +} diff --git a/gateway/router/openapi/logic_test.go b/gateway/router/openapi/logic_test.go new file mode 100644 index 000000000..9c3f92d39 --- /dev/null +++ b/gateway/router/openapi/logic_test.go @@ -0,0 +1,333 @@ +package openapi + +import ( + "context" + "errors" + "reflect" + "testing" + "time" + "unsafe" + + openapi3 "github.com/viant/datly/gateway/router/openapi/openapi3" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/datly/view/tags" + "github.com/viant/tagly/format" + "github.com/viant/xreflect" +) + +type fakeDocService struct { + lookup func(key string) (string, bool, error) +} + +func (f *fakeDocService) Lookup(ctx context.Context, key string) (string, bool, error) { + if f.lookup == nil { + return "", false, nil + } + return f.lookup(key) +} + +func setUnexportedField(target interface{}, fieldName string, value interface{}) { + v := reflect.ValueOf(target).Elem().FieldByName(fieldName) + 
reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +func newTestComponent(t *testing.T) *repository.Component { + t.Helper() + component, err := repository.NewComponent(&contract.Path{Method: "POST", URI: "/v1/test"}, repository.WithView(&view.View{Template: &view.Template{}, Selector: &view.Config{}})) + if err != nil { + t.Fatalf("failed to create component: %v", err) + } + types := xreflect.NewTypes() + setUnexportedField(component, "types", types) + return component +} + +func TestPathsBuilderAddPath(t *testing.T) { + builder := &PathsBuilder{paths: openapi3.Paths{}} + item := &openapi3.PathItem{Summary: "sum"} + builder.AddPath("/v1/pets", item) + if builder.paths["/v1/pets"] != item { + t.Fatalf("path not added") + } +} + +func TestGeneratorHelpers_Table(t *testing.T) { + t.Run("forEachParam recursive and error", func(t *testing.T) { + g := &generator{} + called := 0 + params := state.Parameters{ + {Name: "root", Object: state.Parameters{{Name: "child1"}}, Repeated: state.Parameters{{Name: "child2"}}}, + } + err := g.forEachParam(params, func(parameter *state.Parameter) (bool, error) { + called++ + if parameter.Name == "child1" { + return true, errors.New("boom") + } + return true, nil + }) + if err == nil || err.Error() != "boom" { + t.Fatalf("expected boom, got %v", err) + } + if called < 2 { + t.Fatalf("expected recursive traversal") + } + }) + + t.Run("index parameters", func(t *testing.T) { + g := &generator{} + params := []*openapi3.Parameter{{Name: "a"}, {Name: "b"}} + indexed := g.indexParameters(params) + if indexed["a"].Name != "a" || indexed["b"].Name != "b" { + t.Fatalf("unexpected indexed values") + } + }) + + t.Run("string ptr", func(t *testing.T) { + if *stringPtr("x") != "x" { + t.Fatalf("unexpected value") + } + }) +} + +func TestComponentSchemaHelpers_Table(t *testing.T) { + component := newTestComponent(t) + componentSchema := &ComponentSchema{component: component, schemas: NewContainer()} + + 
t.Run("isRequired", func(t *testing.T) { + req := true + in := contract.Input{Body: state.Type{Parameters: state.Parameters{{Required: &req}}}} + if !componentSchema.isRequired(in) { + t.Fatalf("expected required") + } + }) + + t.Run("description and example defaults", func(t *testing.T) { + desc, err := componentSchema.Description(context.Background(), "A", "default-desc") + if err != nil || desc != "default-desc" { + t.Fatalf("unexpected result: %q %v", desc, err) + } + example, err := componentSchema.Example(context.Background(), "A", "default-ex") + if err != nil || example != "default-ex" { + t.Fatalf("unexpected result: %q %v", example, err) + } + }) + + t.Run("description and example from doc", func(t *testing.T) { + componentSchema.doc = &fakeDocService{lookup: func(key string) (string, bool, error) { + switch key { + case "A": + return "desc", true, nil + case "A$example": + return "ex", true, nil + default: + return "", false, nil + } + }} + desc, err := componentSchema.Description(context.Background(), "A", "default-desc") + if err != nil || desc != "desc" { + t.Fatalf("unexpected description: %q %v", desc, err) + } + example, err := componentSchema.Example(context.Background(), "A", "default-ex") + if err != nil || example != "ex" { + t.Fatalf("unexpected example: %q %v", example, err) + } + }) + + t.Run("description error", func(t *testing.T) { + componentSchema.doc = &fakeDocService{lookup: func(key string) (string, bool, error) { + return "", false, errors.New("lookup") + }} + if _, err := componentSchema.Description(context.Background(), "A", "default"); err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("typed/request/response schema", func(t *testing.T) { + componentSchema.doc = nil + component.Input.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + component.Output.Type = state.Type{Schema: state.NewSchema(reflect.TypeOf(struct{ ID int }{}))} + component.Input.Body = state.Type{Schema: 
state.NewSchema(reflect.TypeOf(struct{ Name string }{}))} + + reqSchema, err := componentSchema.RequestBody(context.Background()) + if err != nil || reqSchema == nil { + t.Fatalf("unexpected request schema result: %v %v", reqSchema, err) + } + + respSchema, err := componentSchema.ResponseBody(context.Background()) + if err != nil || respSchema == nil { + t.Fatalf("unexpected response schema result: %v %v", respSchema, err) + } + + if _, err = componentSchema.TypedSchema(context.Background(), component.Input.Type, "Input", component.IOConfig(), true); err != nil { + t.Fatalf("unexpected typed schema error: %v", err) + } + }) + + t.Run("type name and schema helpers", func(t *testing.T) { + type sample struct{} + types := xreflect.NewTypes() + if err := types.Register("Sample", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(sample{}))); err != nil { + t.Fatalf("register type failed: %v", err) + } + setUnexportedField(component, "types", types) + + if got := componentSchema.TypeName(reflect.TypeOf(sample{}), "fallback"); got != "Sample" { + t.Fatalf("expected Sample, got %s", got) + } + + refl := componentSchema.ReflectSchema("A", reflect.TypeOf(sample{}), "d", component.IOConfig()) + if refl == nil || refl.rType != reflect.TypeOf(sample{}) { + t.Fatalf("unexpected reflect schema") + } + + withTag := componentSchema.SchemaWithTag("F", reflect.TypeOf(sample{}), "d", component.IOConfig(), Tag{}) + if withTag == nil || withTag.path == "" { + t.Fatalf("unexpected schema with tag") + } + }) + + t.Run("schema with tag datatype override", func(t *testing.T) { + type alt struct{ Value string } + reg := xreflect.NewTypes() + if err := reg.Register("Alt", xreflect.WithReflectType(reflect.TypeOf(alt{}))); err != nil { + t.Fatalf("register type failed: %v", err) + } + if component.View.GetResource() == nil { + component.View.SetResource(&view.Resource{}) + } + component.View.GetResource().SetTypes(reg) + withTag := componentSchema.SchemaWithTag("F", 
reflect.TypeOf(struct{ A int }{}), "d", component.IOConfig(), Tag{ + IsInput: true, + Parameter: &tags.Parameter{DataType: "Alt"}, + }) + if withTag.rType != reflect.TypeOf(alt{}) { + t.Fatalf("expected datatype override") + } + }) + + t.Run("schema with tag primitive datatype override", func(t *testing.T) { + withTag := componentSchema.SchemaWithTag("Jwt", reflect.TypeOf(struct{ A int }{}), "d", component.IOConfig(), Tag{ + IsInput: true, + Parameter: &tags.Parameter{DataType: "string"}, + }) + if withTag.rType != reflect.TypeOf("") { + t.Fatalf("expected primitive datatype override to string, got %v", withTag.rType) + } + }) + + t.Run("schema with tag output keeps go type", func(t *testing.T) { + goType := reflect.TypeOf(struct{ A int }{}) + withTag := componentSchema.SchemaWithTag("Out", goType, "d", component.IOConfig(), Tag{ + Parameter: &tags.Parameter{ + DataType: "string", + }, + }) + if withTag.rType != goType { + t.Fatalf("expected output kind to keep go type, got %v", withTag.rType) + } + }) +} + +func TestSchemaContainerCreateSchema_Table(t *testing.T) { + container := NewContainer() + componentSchema := &ComponentSchema{component: newTestComponent(t), schemas: container} + fieldSchema := &Schema{path: "p", description: "d", rType: reflect.TypeOf(1)} + + t.Run("create primitive", func(t *testing.T) { + result, err := container.CreateSchema(context.Background(), componentSchema, fieldSchema) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Type != integerOutput { + t.Fatalf("unexpected type: %s", result.Type) + } + }) + + t.Run("get or generate delegates", func(t *testing.T) { + result, err := componentSchema.GetOrGenerateSchema(context.Background(), fieldSchema) + if err != nil || result.Type != integerOutput { + t.Fatalf("unexpected result: %v %v", result, err) + } + }) + + t.Run("create cached ref", func(t *testing.T) { + container.generatedSchemas["Cached"] = &openapi3.Schema{Type: objectOutput} + cached, err := 
container.createSchema(context.Background(), componentSchema, &Schema{path: "p", description: "d", rType: reflect.TypeOf(struct{}{}), tag: Tag{TypeName: "Cached"}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached.Ref != "#/components/schemas/Cached" { + t.Fatalf("unexpected ref: %s", cached.Ref) + } + }) + + t.Run("create struct and generate schema", func(t *testing.T) { + type rec struct { + ID int `json:"id"` + } + sch, err := container.createSchema(context.Background(), componentSchema, &Schema{ + path: "rec", + description: "record", + rType: reflect.TypeOf(rec{}), + tag: Tag{TypeName: "Rec"}, + ioConfig: componentSchema.component.IOConfig(), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sch.Ref == "" { + t.Fatalf("expected ref schema") + } + }) + + t.Run("addToSchema time format", func(t *testing.T) { + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), componentSchema, dst, &Schema{ + rType: reflect.TypeOf(time.Time{}), + tag: Tag{_tag: format.Tag{TimeLayout: "2006-01-02"}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dst.Type != stringOutput || dst.Format != "date" { + t.Fatalf("unexpected time schema: %s %s", dst.Type, dst.Format) + } + }) + + t.Run("addToSchema struct filtering and inline", func(t *testing.T) { + type payload struct { + Visible string `json:"visible"` + Hidden string `json:"-"` + Internal string `internal:"true"` + Meta map[string]string `json:",inline"` + } + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), componentSchema, dst, &Schema{ + rType: reflect.TypeOf(payload{}), + ioConfig: componentSchema.component.IOConfig(), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := dst.Properties["visible"]; !ok { + t.Fatalf("expected visible field") + } + if _, ok := dst.Properties["hidden"]; ok { + t.Fatalf("did not expect hidden field") + } + if _, ok := 
dst.Properties["internal"]; ok { + t.Fatalf("did not expect internal field") + } + }) +} + +func TestNewComponentSchema(t *testing.T) { + component := &repository.Component{} + got := NewComponentSchema(nil, component, nil) + if got == nil || got.schemas == nil { + t.Fatalf("expected initialized component schema") + } +} diff --git a/gateway/router/openapi/openapi3.go b/gateway/router/openapi/openapi3.go index 34ac163b8..a92a38894 100644 --- a/gateway/router/openapi/openapi3.go +++ b/gateway/router/openapi/openapi3.go @@ -5,7 +5,6 @@ import ( "fmt" openapi "github.com/viant/datly/gateway/router/openapi/openapi3" "github.com/viant/datly/repository" - "github.com/viant/datly/repository/contract" "github.com/viant/datly/shared" "github.com/viant/datly/view" "github.com/viant/datly/view/state" @@ -50,6 +49,24 @@ type ( } ) +func isRequestDerivedInputKind(kind state.Kind) bool { + switch kind { + case state.KindHeader, state.KindRequestBody, state.KindQuery, state.KindForm: + return true + default: + return false + } +} + +func isOpenAPIParameterKind(kind state.Kind) bool { + switch kind { + case state.KindHeader, state.KindQuery, state.KindForm: + return true + default: + return false + } +} + func (g *generator) GenerateSpec(ctx context.Context, repoComponents *repository.Service, info openapi.Info, providers ...*repository.Provider) (*openapi.OpenAPI, error) { components := &openapi.Components{} @@ -77,99 +94,6 @@ func GenerateOpenAPI3Spec(ctx context.Context, components *repository.Service, i }).GenerateSpec(ctx, components, info, providers...) 
} -func (g *generator) generatePaths(ctx context.Context, components *repository.Service, providers []*repository.Provider) (*SchemaContainer, openapi.Paths, error) { - container := NewContainer() - builder := &PathsBuilder{paths: openapi.Paths{}} - var retErr error - pathItem := &openapi.PathItem{} - for _, provider := range providers { - component, err := provider.Component(ctx) - if err != nil { - retErr = err - } - if component == nil { - fmt.Printf("provider.Component(ctx) returned nil\n") - continue - } - componentSchema := NewComponentSchema(components, component, container) - operation, err := g.generateOperation(ctx, componentSchema) - if err != nil { - retErr = err - } - switch component.Method { - case http.MethodGet: - pathItem.Get = operation - case http.MethodPost: - pathItem.Post = operation - case http.MethodDelete: - pathItem.Delete = operation - case http.MethodPut: - pathItem.Put = operation - case http.MethodPatch: - pathItem.Patch = operation - } - builder.AddPath(component.URI, pathItem) - } - - return container, builder.paths, retErr -} - -func (g *generator) generateOperation(ctx context.Context, component *ComponentSchema) (*openapi.Operation, error) { - body, err := g.requestBody(ctx, component) - if err != nil { - return nil, err - } - - parameters, err := g.getAllViewsParameters(ctx, component, component.component.View) - - if err != nil { - return nil, err - } - - if err := g.forEachParam(component.component.Output.Type.Parameters, func(parameter *state.Parameter) (bool, error) { - if parameter.In.Kind == state.KindComponent { - method, URI := shared.ExtractPath(parameter.In.Name) - provider, err := component.components.Registry().LookupProvider(ctx, &contract.Path{ - URI: URI, - Method: method, - }) - - if err != nil { - return false, err - } - - paramComponent, err := provider.Component(ctx) - if err != nil { - return false, err - } - - viewsParameters, err := g.getAllViewsParameters(ctx, NewComponentSchema(component.components, 
paramComponent, component.schemas), paramComponent.View) - if err != nil { - return false, err - } - - parameters = append(parameters, viewsParameters...) - } - - return true, nil - }); err != nil { - return nil, err - } - - responses, err := g.responses(ctx, component) - if err != nil { - return nil, err - } - - operation := &openapi.Operation{ - Parameters: dedupe(parameters), - RequestBody: body, - Responses: responses, - } - - return operation, nil -} - func dedupe(parameters []*openapi.Parameter) openapi.Parameters { index := map[paramLocation]bool{} var result []*openapi.Parameter @@ -282,6 +206,9 @@ func (g *generator) convertParam(ctx context.Context, component *ComponentSchema } if param.In.Kind == state.KindParam { baseParam := component.component.LookupParameter(param.In.Name) + if baseParam == nil || !isRequestDerivedInputKind(baseParam.In.Kind) { + return nil, false, nil + } return g.convertParam(ctx, component, baseParam, description) } @@ -301,7 +228,7 @@ func (g *generator) convertParam(ctx context.Context, component *ComponentSchema return result, true, nil } - if !param.IsHTTPParameter() { + if !isOpenAPIParameterKind(param.In.Kind) { return nil, false, nil } @@ -318,16 +245,21 @@ func (g *generator) convertParam(ctx context.Context, component *ComponentSchema } table := "" + var parameterTag *tags.Parameter if param.Tag != "" { - if datlyTags, _ := tags.Parse(reflect.StructTag(param.Tag), nil, tags.ViewTag); datlyTags != nil && datlyTags.View != nil { - table = datlyTags.View.Table + if datlyTags, _ := tags.Parse(reflect.StructTag(param.Tag), nil, tags.ViewTag, tags.ParameterTag); datlyTags != nil { + parameterTag = datlyTags.Parameter + if datlyTags.View != nil { + table = datlyTags.View.Table + } } - } schema, err := component.GenerateSchema(ctx, component.SchemaWithTag(param.Name, param.Schema.Type(), "Parameter "+param.Name+" schema", component.component.IOConfig(), Tag{ Format: param.DateFormat, IsNullable: !param.IsRequired(), Table: table, 
+ Parameter: parameterTag, + IsInput: true, })) if err != nil { @@ -407,7 +339,7 @@ func (g *generator) requestBody(ctx context.Context, component *ComponentSchema) func (g *generator) responses(ctx context.Context, component *ComponentSchema) (openapi.Responses, error) { method := component.component.Method if method == http.MethodOptions { - return nil, nil + return openapi.Responses{}, nil } responseSchema, err := component.ResponseBody(ctx) @@ -421,27 +353,27 @@ func (g *generator) responses(ctx context.Context, component *ComponentSchema) ( } responses := openapi.Responses{} - responses[200] = &openapi.Response{ + openapi.SetResponse(responses, openapi.ResponseOK, &openapi.Response{ Description: stringPtr("Success response"), Content: map[string]*openapi.MediaType{ ApplicationJson: { Schema: schema, }, }, - } + }) errorSchema, err := component.GetOrGenerateSchema(ctx, component.ReflectSchema("ErrorResponse", errorType, errorSchemaDescription, component.component.IOConfig())) if err != nil { return nil, err } - responses["default"] = &openapi.Response{ + openapi.SetResponse(responses, openapi.ResponseDefault, &openapi.Response{ Description: stringPtr("Error response. 
The view and param may be empty, but one of the message or object should be specified"), Content: map[string]*openapi.MediaType{ ApplicationJson: { Schema: errorSchema, }, - }} + }}) return responses, nil } diff --git a/gateway/router/openapi/openapi3/additional_branches_test.go b/gateway/router/openapi/openapi3/additional_branches_test.go new file mode 100644 index 000000000..312607c18 --- /dev/null +++ b/gateway/router/openapi/openapi3/additional_branches_test.go @@ -0,0 +1,120 @@ +package openapi3 + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" +) + +func TestResponsesHelpersAndOperationMarshal(t *testing.T) { + responses := Responses{} + SetResponse(responses, ResponseOK, &Response{Description: strPtr("ok")}) + SetResponse(responses, ResponseDefault, &Response{Description: strPtr("fallback")}) + SetResponse(responses, ResponseCreated, &Response{Description: strPtr("created")}) + + if got, ok := GetResponse(responses, ResponseOK); !ok || got == nil || got.Description == nil || *got.Description != "ok" { + t.Fatalf("expected integer-key lookup to resolve 200 response") + } + if _, ok := GetResponse(responses, string(ResponseOK)); !ok { + t.Fatalf("expected string-key lookup to resolve 200 response") + } + if _, ok := GetResponse(responses, ResponseDefault); !ok { + t.Fatalf("expected default response") + } + if got, ok := GetResponse(responses, ResponseCreated); !ok || got == nil || got.Description == nil || *got.Description != "created" { + t.Fatalf("expected ResponseCode lookup to resolve 201 response") + } + if len(responses) != 3 { + t.Fatalf("expected three responses to be set") + } + + op := &Operation{ + Summary: "sum", + Responses: responses, + Extension: Extension{"x-extra": true}, + } + data, err := json.Marshal(op) + if err != nil { + t.Fatalf("unexpected marshal error: %v", err) + } + if !strings.Contains(string(data), "\""+string(ResponseOK)+"\"") || !strings.Contains(string(data), "x-extra") { + t.Fatalf("expected 
marshaled operation to include response and extension: %s", string(data)) + } +} + +func assertNormalize[T ResponseKey](t *testing.T, name string, input T, expected ResponseCode) { + t.Helper() + t.Run(name, func(t *testing.T) { + if got := NormalizeResponseCode(input); got != expected { + t.Fatalf("expected %q, got %q", expected, got) + } + }) +} + +func TestNormalizeResponseCode_Table(t *testing.T) { + assertNormalize(t, "string", "default", ResponseCode("default")) + assertNormalize(t, "response code", ResponseDefault, ResponseCode("default")) + assertNormalize(t, "response code literal", ResponseCodeLiteral("200"), ResponseOK) + assertNormalize(t, "int", int(200), ResponseOK) + assertNormalize(t, "int8", int8(101), ResponseCode("101")) + assertNormalize(t, "int16", int16(202), ResponseCode("202")) + assertNormalize(t, "int32", int32(203), ResponseCode("203")) + assertNormalize(t, "int64", int64(204), ResponseCode("204")) + assertNormalize(t, "uint", uint(205), ResponseCode("205")) + assertNormalize(t, "uint8", uint8(206), ResponseCode("206")) + assertNormalize(t, "uint16", uint16(207), ResponseCode("207")) + assertNormalize(t, "uint32", uint32(208), ResponseCode("208")) + assertNormalize(t, "uint64", uint64(209), ResponseCode("209")) +} + +func TestUnmarshalYAMLErrorBranches_Table(t *testing.T) { + tests := []struct { + name string + target yamlUnmarshaller + source interface{} + firstErr error + secondErr error + wantErr string + }{ + {name: "parameter first err", target: &Parameter{}, source: Parameter{}, firstErr: errors.New("p-first"), wantErr: "p-first"}, + {name: "link second err", target: &Link{}, source: Link{}, secondErr: errors.New("l-second"), wantErr: "l-second"}, + {name: "request body second err", target: &RequestBody{}, source: RequestBody{}, secondErr: errors.New("rb-second"), wantErr: "rb-second"}, + {name: "response second err", target: &Response{}, source: Response{}, secondErr: errors.New("resp-second"), wantErr: "resp-second"}, + {name: 
"security second err", target: &SecurityScheme{}, source: SecurityScheme{}, secondErr: errors.New("sec-second"), wantErr: "sec-second"}, + {name: "schema second err", target: &Schema{}, source: Schema{}, secondErr: errors.New("schema-second"), wantErr: "schema-second"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.target.UnmarshalYAML(context.Background(), yamlDecoder(tt.source, nil, tt.firstErr, tt.secondErr)) + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected err containing %q, got %v", tt.wantErr, err) + } + }) + } +} + +func TestOperationResponsesBoundary(t *testing.T) { + t.Run("unmarshal json initializes responses", func(t *testing.T) { + var op Operation + if err := json.Unmarshal([]byte(`{"summary":"s"}`), &op); err != nil { + t.Fatalf("unexpected unmarshal error: %v", err) + } + if op.Responses == nil { + t.Fatalf("expected non-nil responses after unmarshal") + } + }) + + t.Run("unmarshal yaml initializes responses", func(t *testing.T) { + var op Operation + err := op.UnmarshalYAML(context.Background(), yamlDecoder(map[string]interface{}{"summary": "s"}, map[string]interface{}{"x-a": 1}, nil, nil)) + if err != nil { + t.Fatalf("unexpected yaml unmarshal error: %v", err) + } + if op.Responses == nil { + t.Fatalf("expected non-nil responses after yaml unmarshal") + } + }) +} diff --git a/gateway/router/openapi/openapi3/coverage_branches_test.go b/gateway/router/openapi/openapi3/coverage_branches_test.go new file mode 100644 index 000000000..89d3656b9 --- /dev/null +++ b/gateway/router/openapi/openapi3/coverage_branches_test.go @@ -0,0 +1,186 @@ +package openapi3 + +import ( + "encoding/json" + "strings" + "testing" +) + +func decodeSequence(values ...interface{}) func(dest interface{}) error { + index := 0 + return func(dest interface{}) error { + if index >= len(values) { + return nil + } + value := values[index] + index++ + if err, ok := value.(error); ok { + return err + } + data, 
err := json.Marshal(value) + if err != nil { + return err + } + return json.Unmarshal(data, dest) + } +} + +func TestMarshalNoExtension_Table(t *testing.T) { + trueVal := true + tests := []struct { + name string + value interface{} + }{ + {name: "components", value: &Components{Schemas: Schemas{"Pet": {Type: "object"}}}}, + {name: "parameter", value: &Parameter{Name: "id", In: "query"}}, + {name: "security", value: &SecurityScheme{Type: "http"}}, + {name: "example", value: &Example{Summary: "s"}}, + {name: "server", value: &Server{URL: "http://example"}}, + {name: "server variable", value: &ServerVariable{Default: "dev"}}, + {name: "info", value: &Info{Title: "api", Version: "1.0"}}, + {name: "contact", value: &Contact{Name: "n"}}, + {name: "license", value: &License{Name: "mit"}}, + {name: "tag", value: &Tag{Name: "n"}}, + {name: "path item", value: &PathItem{Summary: "sum"}}, + {name: "encoding", value: &Encoding{ContentType: "application/json", Explode: &trueVal}}, + {name: "request body", value: &RequestBody{Description: "d"}}, + {name: "external", value: &ExternalDocumentation{URL: "http://example"}}, + {name: "response", value: &Response{Description: strPtr("ok")}}, + {name: "media", value: &MediaType{Example: map[string]interface{}{"a": 1}}}, + {name: "operation", value: &Operation{Summary: "sum", Responses: Responses{}}}, + {name: "link", value: &Link{OperationID: "op"}}, + {name: "schema", value: &Schema{Type: "object"}}, + {name: "xml", value: &XML{Name: "node"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.value) + if err != nil { + t.Fatalf("unexpected marshal error: %v", err) + } + if strings.Contains(string(data), "x-") { + t.Fatalf("did not expect extension key in %s", string(data)) + } + }) + } +} + +func TestYAMLRefAndNonRefBranches(t *testing.T) { + ctx := seedSession() + + t.Run("parameter non ref and ref", func(t *testing.T) { + var nonRef Parameter + if err := nonRef.UnmarshalYAML(ctx, 
decodeSequence(Parameter{Name: "id", In: "query"})); err != nil { + t.Fatalf("unexpected non-ref error: %v", err) + } + + var refHit Parameter + if err := refHit.UnmarshalYAML(ctx, decodeSequence(map[string]interface{}{"$ref": "#/components/parameters/id"})); err != nil { + t.Fatalf("unexpected ref lookup error: %v", err) + } + }) + + t.Run("link non ref and ref", func(t *testing.T) { + var nonRef Link + if err := nonRef.UnmarshalYAML(ctx, decodeSequence(Link{OperationID: "op"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected non-ref error: %v", err) + } + + var refHit Link + if err := refHit.UnmarshalYAML(ctx, decodeSequence(map[string]interface{}{"$ref": "#/components/links/Self"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected ref lookup error: %v", err) + } + }) + + t.Run("request body non ref and ref", func(t *testing.T) { + var nonRef RequestBody + if err := nonRef.UnmarshalYAML(ctx, decodeSequence(RequestBody{Description: "d"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected non-ref error: %v", err) + } + + var refHit RequestBody + if err := refHit.UnmarshalYAML(ctx, decodeSequence(map[string]interface{}{"$ref": "#/components/requestBodies/Create"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected ref lookup error: %v", err) + } + }) + + t.Run("response non ref and ref", func(t *testing.T) { + var nonRef Response + if err := nonRef.UnmarshalYAML(ctx, decodeSequence(Response{Description: strPtr("ok")}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected non-ref error: %v", err) + } + + var refHit Response + if err := refHit.UnmarshalYAML(ctx, decodeSequence(map[string]interface{}{"$ref": "#/components/responses/Default"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected ref lookup error: %v", err) + } + }) + + t.Run("security non ref and ref", func(t *testing.T) { + var nonRef SecurityScheme + if err := 
nonRef.UnmarshalYAML(ctx, decodeSequence(SecurityScheme{Type: "http"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected non-ref error: %v", err) + } + + var refHit SecurityScheme + if err := refHit.UnmarshalYAML(ctx, decodeSequence(map[string]interface{}{"$ref": "#/components/securitySchemes/Bearer"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected ref lookup error: %v", err) + } + }) + + t.Run("schema non ref and ref", func(t *testing.T) { + var nonRef Schema + if err := nonRef.UnmarshalYAML(ctx, decodeSequence(Schema{Type: "object"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected non-ref error: %v", err) + } + + var refHit Schema + if err := refHit.UnmarshalYAML(ctx, decodeSequence(map[string]interface{}{"$ref": "#/components/schemas/Pet"}, map[string]interface{}{"x-a": 1})); err != nil { + t.Fatalf("unexpected ref lookup error: %v", err) + } + }) +} + +func TestSessionLookupMissingBranches(t *testing.T) { + s := NewSession() + s.Location = "loc" + s.RegisterComponents("loc", &Components{ + Schemas: Schemas{"Pet": {Type: "object"}}, + Parameters: ParametersMap{"id": {Name: "id", In: "query"}}, + Headers: Headers{"/components/headers/Trace": {Name: "Trace", In: "header"}}, + RequestBodies: RequestBodies{"/components/requestBodies/Create": {Description: "create"}}, + Responses: Responses{"/components/responses/Default": {Description: strPtr("default")}}, + SecuritySchemes: SecuritySchemes{"/components/securitySchemes/Bearer": {Type: "http"}}, + Examples: Examples{"/components/examples/Sample": {Summary: "sample"}}, + Links: Links{"/components/links/Self": {OperationID: "self"}}, + Callbacks: Callbacks{"/components/callbacks/Event": {Ref: "eventRef"}}, + }) + + tests := []struct { + name string + lookup func() error + wantErr string + }{ + {name: "parameter missing location", lookup: func() error { _, err := s.LookupParameter("other", "#/components/parameters/id"); return err }, wantErr: "failed 
to lookup location"}, + {name: "header missing value", lookup: func() error { _, err := s.LookupHeaders("loc", "#/components/headers/Other"); return err }, wantErr: "failed to lookup"}, + {name: "request missing value", lookup: func() error { _, err := s.LookupRequestBody("loc", "#/components/requestBodies/Other"); return err }, wantErr: "failed to lookup"}, + {name: "response missing value", lookup: func() error { _, err := s.LookupResponse("loc", "#/components/responses/Other"); return err }, wantErr: "failed to lookup"}, + {name: "security missing value", lookup: func() error { + _, err := s.LookupSecurityScheme("loc", "#/components/securitySchemes/Other") + return err + }, wantErr: "failed to lookup"}, + {name: "example missing value", lookup: func() error { _, err := s.LookupExample("loc", "#/components/examples/Other"); return err }, wantErr: "failed to lookup"}, + {name: "link missing value", lookup: func() error { _, err := s.LookupLink("loc", "#/components/links/Other"); return err }, wantErr: "failed to lookup"}, + {name: "callback missing value", lookup: func() error { _, err := s.LookupCallback("loc", "#/components/callbacks/Other"); return err }, wantErr: "failed to lookup"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.lookup() + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected err containing %q, got %v", tt.wantErr, err) + } + }) + } +} diff --git a/gateway/router/openapi/openapi3/model_methods_test.go b/gateway/router/openapi/openapi3/model_methods_test.go new file mode 100644 index 000000000..ccab0a767 --- /dev/null +++ b/gateway/router/openapi/openapi3/model_methods_test.go @@ -0,0 +1,256 @@ +package openapi3 + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" +) + +type yamlUnmarshaller interface { + UnmarshalYAML(ctx context.Context, fn func(dest interface{}) error) error +} + +func yamlDecoder(first interface{}, ext map[string]interface{}, 
firstErr, secondErr error) func(dest interface{}) error { + call := 0 + return func(dest interface{}) error { + call++ + if call == 1 { + if firstErr != nil { + return firstErr + } + if first == nil { + return nil + } + b, err := json.Marshal(first) + if err != nil { + return err + } + return json.Unmarshal(b, dest) + } + if secondErr != nil { + return secondErr + } + if ext == nil { + ext = map[string]interface{}{} + } + b, err := json.Marshal(ext) + if err != nil { + return err + } + return json.Unmarshal(b, dest) + } +} + +func TestMergeJSON(t *testing.T) { + tests := []struct { + name string + j1 []byte + j2 []byte + expect string + }{ + {name: "empty base", j1: []byte("{}"), j2: []byte(`{"x-a":1}`), expect: `{"x-a":1}`}, + {name: "merged", j1: []byte(`{"a":1}`), j2: []byte(`{"x-a":1}`), expect: `{"a":1,"x-a":1}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := string(mergeJSON(tt.j1, tt.j2)); got != tt.expect { + t.Fatalf("expected %s, got %s", tt.expect, got) + } + }) + } +} + +func TestExtensionFunctions(t *testing.T) { + t.Run("unmarshal json keeps x keys", func(t *testing.T) { + ext := Extension{} + if err := ext.UnmarshalJSON([]byte(`{"x-a":1,"a":2}`)); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := ext["x-a"]; !ok { + t.Fatalf("expected x-a key") + } + if _, ok := ext["a"]; ok { + t.Fatalf("did not expect non-extension key") + } + }) + + t.Run("custom extension yaml", func(t *testing.T) { + custom := CustomExtension{} + fn := yamlDecoder(map[string]interface{}{"x-a": 1, "a": 2}, nil, nil, nil) + if err := custom.UnmarshalYAML(context.Background(), fn); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := custom["x-a"]; !ok { + t.Fatalf("expected x-a") + } + if _, ok := custom["a"]; ok { + t.Fatalf("did not expect a") + } + }) +} + +func TestMarshalJSONWithExtensions_Table(t *testing.T) { + trueValue := true + tests := []struct { + name string + value interface{} + wantErr 
bool + }{ + {name: "components", value: &Components{Extension: Extension{"x-a": 1}, Schemas: Schemas{"Pet": {Type: "object"}}}}, + {name: "parameter", value: &Parameter{Extension: Extension{"x-a": 1}, Name: "id", In: "query"}}, + {name: "security", value: &SecurityScheme{Extension: Extension{"x-a": 1}, Type: "http"}}, + {name: "oauth flows", value: &OAuthFlows{Extension: Extension{"x-a": 1}, Password: &OAuthFlow{TokenURL: "token", Scopes: map[string]string{"s": "v"}}}}, + {name: "oauth flow", value: &OAuthFlow{Extension: Extension{"x-a": 1}, TokenURL: "token", Scopes: map[string]string{"s": "v"}}}, + {name: "example", value: &Example{Extension: Extension{"x-a": 1}, Summary: "s"}}, + {name: "server", value: &Server{Extension: Extension{"x-a": 1}, URL: "http://example"}}, + {name: "server variable", value: &ServerVariable{Extension: Extension{"x-a": 1}, Default: "dev"}}, + {name: "info", value: &Info{Extension: Extension{"x-a": 1}, Title: "api", Version: "1.0"}}, + {name: "contact", value: &Contact{Extension: Extension{"x-a": 1}, Name: "n"}}, + {name: "license", value: &License{Extension: Extension{"x-a": 1}, Name: "mit"}}, + {name: "tag", value: &Tag{Extension: Extension{"x-a": 1}, Name: "n"}}, + {name: "path item", value: &PathItem{Extension: Extension{"x-a": 1}, Summary: "sum"}}, + {name: "encoding", value: &Encoding{Extension: Extension{"x-a": 1}, ContentType: "application/json", Explode: &trueValue}}, + {name: "request body", value: &RequestBody{Extension: Extension{"x-a": 1}, Description: "d"}}, + {name: "external doc", value: &ExternalDocumentation{Extension: Extension{"x-a": 1}, URL: "http://example"}}, + {name: "response", value: &Response{Extension: Extension{"x-a": 1}, Description: strPtr("ok")}}, + {name: "media", value: &MediaType{Extension: Extension{"x-a": 1}, Example: map[string]interface{}{"a": 1}}}, + {name: "operation", value: &Operation{Extension: Extension{"x-a": 1}, Summary: "sum"}}, + {name: "link", value: &Link{Extension: Extension{"x-a": 1}, 
OperationID: "op"}}, + {name: "schema", value: &Schema{Extension: Extension{"x-a": 1}, Type: "object"}}, + {name: "xml", value: &XML{Extension: Extension{"x-a": 1}, Name: "node"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.value) + if tt.wantErr { + if err == nil { + t.Fatalf("expected marshal error") + } + return + } + if err != nil { + t.Fatalf("marshal error: %v", err) + } + if !strings.Contains(string(data), "x-a") { + t.Fatalf("expected extension in json: %s", string(data)) + } + }) + } +} + +func TestUnmarshalJSON_Table(t *testing.T) { + tests := []struct { + name string + target interface{} + json string + }{ + {name: "components", target: &Components{}, json: `{"schemas":{"Pet":{"type":"object"}}}`}, + {name: "parameter", target: &Parameter{}, json: `{"name":"id","in":"query"}`}, + {name: "security", target: &SecurityScheme{}, json: `{"type":"http"}`}, + {name: "oauth flows", target: &OAuthFlows{}, json: `{"password":{"tokenUrl":"token","scopes":{"s":"v"}}}`}, + {name: "oauth flow", target: &OAuthFlow{}, json: `{"tokenUrl":"token","scopes":{"s":"v"}}`}, + {name: "example", target: &Example{}, json: `{"summary":"s"}`}, + {name: "server", target: &Server{}, json: `{"url":"http://example"}`}, + {name: "server variable", target: &ServerVariable{}, json: `{"default":"dev"}`}, + {name: "info", target: &Info{}, json: `{"title":"api","version":"1"}`}, + {name: "contact", target: &Contact{}, json: `{"name":"n"}`}, + {name: "license", target: &License{}, json: `{"name":"mit"}`}, + {name: "tag", target: &Tag{}, json: `{"name":"n"}`}, + {name: "path", target: &PathItem{}, json: `{"summary":"sum"}`}, + {name: "encoding", target: &Encoding{}, json: `{"contentType":"application/json"}`}, + {name: "request", target: &RequestBody{}, json: `{"description":"d"}`}, + {name: "external", target: &ExternalDocumentation{}, json: `{"url":"http://example"}`}, + {name: "response", target: &Response{}, json: 
`{"description":"ok"}`}, + {name: "media", target: &MediaType{}, json: `{"example":{"a":1}}`}, + {name: "operation", target: &Operation{}, json: `{"summary":"sum"}`}, + {name: "link", target: &Link{}, json: `{"operationId":"op"}`}, + {name: "schema", target: &Schema{}, json: `{"type":"object"}`}, + {name: "xml", target: &XML{}, json: `{"name":"node"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := json.Unmarshal([]byte(tt.json), tt.target); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + }) + } +} + +func seedSession() context.Context { + sessionCtx := NewSessionContext(context.Background()) + session := LookupSession(sessionCtx) + session.Location = "main" + session.RegisterComponents("main", &Components{ + Schemas: Schemas{"Pet": {Type: "object"}}, + Parameters: ParametersMap{"id": {Name: "id", In: "query"}}, + Headers: Headers{"/components/headers/Trace": {Name: "Trace", In: "header"}}, + RequestBodies: RequestBodies{"/components/requestBodies/Create": {Description: "create"}}, + Responses: Responses{"/components/responses/Default": {Description: strPtr("default")}}, + SecuritySchemes: SecuritySchemes{"/components/securitySchemes/Bearer": {Type: "http"}}, + Examples: Examples{"/components/examples/Sample": {Summary: "sample"}}, + Links: Links{"/components/links/Self": {OperationID: "self"}}, + Callbacks: Callbacks{"/components/callbacks/Event": {Ref: "inner"}}, + }) + return sessionCtx +} + +func TestUnmarshalYAML_Table(t *testing.T) { + tests := []struct { + name string + target yamlUnmarshaller + source interface{} + ext map[string]interface{} + firstErr error + secondErr error + wantErr string + }{ + {name: "components", target: &Components{}, source: Components{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "parameter ref", target: &Parameter{}, source: Parameter{Ref: "#/components/parameters/id"}}, + {name: "security ref", target: &SecurityScheme{}, source: SecurityScheme{Ref: 
"#/components/securitySchemes/Bearer"}, ext: map[string]interface{}{"x-a": 1}}, + {name: "oauth flows", target: &OAuthFlows{}, source: OAuthFlows{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "oauth flow", target: &OAuthFlow{}, source: OAuthFlow{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "callback ref", target: &CallbackRef{}, source: CallbackRef{Ref: "#/components/callbacks/Event"}}, + {name: "example", target: &Example{}, source: Example{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "server", target: &Server{}, source: Server{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "server variable", target: &ServerVariable{}, source: ServerVariable{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "info", target: &Info{}, source: Info{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "contact", target: &Contact{}, source: Contact{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "license", target: &License{}, source: License{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "tag", target: &Tag{}, source: Tag{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "path", target: &PathItem{}, source: PathItem{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "encoding", target: &Encoding{}, source: Encoding{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "request body ref", target: &RequestBody{}, source: RequestBody{Ref: "#/components/requestBodies/Create"}, ext: map[string]interface{}{"x-a": 1}}, + {name: "external doc", target: &ExternalDocumentation{}, source: ExternalDocumentation{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "response ref", target: &Response{}, source: Response{Ref: "#/components/responses/Default"}, ext: map[string]interface{}{"x-a": 1}}, + {name: "media", target: &MediaType{}, source: MediaType{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "operation", target: &Operation{}, source: nil, ext: map[string]interface{}{"x-a": 1}}, + {name: "link ref", target: &Link{}, source: Link{Ref: 
"#/components/links/Self"}, ext: map[string]interface{}{"x-a": 1}}, + {name: "schema ref", target: &Schema{}, source: Schema{Ref: "#/components/schemas/Pet"}, ext: map[string]interface{}{"x-a": 1}}, + {name: "xml", target: &XML{}, source: XML{}, ext: map[string]interface{}{"x-a": 1}}, + {name: "first decoder error", target: &XML{}, firstErr: errors.New("first decoder"), wantErr: "first decoder"}, + {name: "second decoder error", target: &XML{}, source: XML{}, secondErr: errors.New("second decoder"), wantErr: "second decoder"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := seedSession() + err := tt.target.UnmarshalYAML(ctx, yamlDecoder(tt.source, tt.ext, tt.firstErr, tt.secondErr)) + if tt.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected err containing %q, got %v", tt.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func strPtr(v string) *string { return &v } diff --git a/gateway/router/openapi/openapi3/operation.go b/gateway/router/openapi/openapi3/operation.go index 1421f2017..cd37e8a91 100644 --- a/gateway/router/openapi/openapi3/operation.go +++ b/gateway/router/openapi/openapi3/operation.go @@ -50,6 +50,9 @@ func (o *Operation) UnmarshalJSON(b []byte) error { if err != nil { return err } + if tmp.Responses == nil { + tmp.Responses = Responses{} + } *o = Operation(tmp) return o.Extension.UnmarshalJSON(b) } @@ -58,6 +61,9 @@ func (o *Operation) MarshalJSON() ([]byte, error) { type temp Operation tmp := temp(*o) tmp.Extension = nil + if tmp.Responses == nil { + tmp.Responses = Responses{} + } data, err := json.Marshal(tmp) if err != nil { return nil, err @@ -76,7 +82,6 @@ func (o *Operation) MarshalJSON() ([]byte, error) { return res, nil } - func (o *Operation) UnmarshalYAML(ctx context.Context, fn func(dest interface{}) error) error { type temp Operation tmp := temp(*o) @@ -90,6 +95,9 @@ func (o *Operation) 
UnmarshalYAML(ctx context.Context, fn func(dest interface{}) return err } tmp.Extension = Extension(ext) + if tmp.Responses == nil { + tmp.Responses = Responses{} + } *o = Operation(tmp) return nil } diff --git a/gateway/router/openapi/openapi3/response.go b/gateway/router/openapi/openapi3/response.go index 656065726..47efb2e3d 100644 --- a/gateway/router/openapi/openapi3/response.go +++ b/gateway/router/openapi/openapi3/response.go @@ -3,11 +3,12 @@ package openapi3 import ( "context" "encoding/json" + "fmt" ) // Responses is specified by OpenAPI/Swagger 3.0 standard. type ( - Responses map[interface{}]*Response + Responses map[string]*Response // Response is specified by OpenAPI/Swagger 3.0 standard. Response struct { @@ -20,6 +21,54 @@ type ( } ) +const ( + ResponseContinue ResponseCode = "100" + ResponseOK ResponseCode = "200" + ResponseCreated ResponseCode = "201" + ResponseAccepted ResponseCode = "202" + ResponseNoContent ResponseCode = "204" + ResponseBadRequest ResponseCode = "400" + ResponseUnauthorized ResponseCode = "401" + ResponseForbidden ResponseCode = "403" + ResponseNotFound ResponseCode = "404" + ResponseConflict ResponseCode = "409" + ResponseUnprocessable ResponseCode = "422" + ResponseInternalServerErr ResponseCode = "500" + ResponseBadGateway ResponseCode = "502" + ResponseServiceUnavailable ResponseCode = "503" + ResponseDefault ResponseCode = "default" +) + +type ( + ResponseCode string + ResponseCodeLiteral string + + ResponseKey interface { + ~string | ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 + } +) + +func NormalizeResponseCode[T ResponseKey](code T) ResponseCode { + return ResponseCode(fmt.Sprintf("%v", code)) +} + +func SetResponse[T ResponseKey](r Responses, code T, response *Response) { + key := NormalizeResponseCode(code) + if key == "" { + return + } + r[string(key)] = response +} + +func GetResponse[T ResponseKey](r Responses, code T) (*Response, bool) { + key := 
NormalizeResponseCode(code) + if key == "" { + return nil, false + } + value, ok := r[string(key)] + return value, ok +} + func (r *Response) UnmarshalJSON(b []byte) error { type temp Response var tmp = temp{} diff --git a/gateway/router/openapi/openapi3/session.go b/gateway/router/openapi/openapi3/session.go index 741d4c72b..8c29c5a30 100644 --- a/gateway/router/openapi/openapi3/session.go +++ b/gateway/router/openapi/openapi3/session.go @@ -38,7 +38,7 @@ func (s *Session) RegisterComponents(location string, components *Components) { components.RequestBodies = map[string]*RequestBody{} } if len(components.Responses) == 0 { - components.Responses = map[interface{}]*Response{} + components.Responses = map[string]*Response{} } if len(components.SecuritySchemes) == 0 { components.SecuritySchemes = map[string]*SecurityScheme{} @@ -60,8 +60,7 @@ func (s *Session) RegisterComponents(location string, components *Components) { // LookupSchema lookups schema func (s *Session) LookupSchema(location string, ref string) (*Schema, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { id := s.normalizeRef(ref[1:], "/components/schemas/") components, ok := s.components[location] if !ok { @@ -74,20 +73,13 @@ func (s *Session) LookupSchema(location string, ref string) (*Schema, error) { result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupParameter lookup parameters func (s *Session) LookupParameter(location string, ref string) (*Parameter, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { id := s.normalizeRef(ref[1:], "/components/parameters/") components, ok := s.components[location] if !ok { @@ -100,20 +92,13 @@ func (s *Session) LookupParameter(location string, ref string) (*Parameter, erro result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, 
fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupHeaders lookup headers func (s *Session) LookupHeaders(location string, ref string) (*Parameter, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { components, ok := s.components[location] if !ok { return nil, fmt.Errorf("failed to lookup location: %v", location) @@ -125,20 +110,13 @@ func (s *Session) LookupHeaders(location string, ref string) (*Parameter, error) result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupRequestBody lookup request body func (s *Session) LookupRequestBody(location string, ref string) (*RequestBody, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { components, ok := s.components[location] if !ok { return nil, fmt.Errorf("failed to lookup location: %v", location) @@ -150,20 +128,13 @@ func (s *Session) LookupRequestBody(location string, ref string) (*RequestBody, result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupResponse lookup response func (s *Session) LookupResponse(location string, ref string) (*Response, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { components, ok := s.components[location] if !ok { return nil, fmt.Errorf("failed to lookup location: %v", location) @@ -175,20 +146,13 @@ func (s *Session) LookupResponse(location string, ref string) (*Response, error) result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupSecurityScheme lookup security scheme func (s *Session) LookupSecurityScheme(location string, ref string) (*SecurityScheme, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { components, ok := 
s.components[location] if !ok { return nil, fmt.Errorf("failed to lookup location: %v", location) @@ -200,20 +164,13 @@ func (s *Session) LookupSecurityScheme(location string, ref string) (*SecuritySc result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupExample lookup example func (s *Session) LookupExample(location string, ref string) (*Example, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { components, ok := s.components[location] if !ok { return nil, fmt.Errorf("failed to lookup location: %v", location) @@ -225,20 +182,13 @@ func (s *Session) LookupExample(location string, ref string) (*Example, error) { result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupLink lookup link func (s *Session) LookupLink(location string, ref string) (*Link, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { components, ok := s.components[location] if !ok { return nil, fmt.Errorf("failed to lookup location: %v", location) @@ -250,20 +200,13 @@ func (s *Session) LookupLink(location string, ref string) (*Link, error) { result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, fmt.Errorf("unsupported: %v, at %v", ref, location) } // LookupLink lookup callback func (s *Session) LookupCallback(location string, ref string) (*CallbackRef, error) { - switch ref[0] { - case '#': + if len(ref) > 0 && ref[0] == '#' { components, ok := s.components[location] if !ok { return nil, fmt.Errorf("failed to lookup location: %v", location) @@ -275,12 +218,6 @@ func (s *Session) LookupCallback(location string, ref string) (*CallbackRef, err result := *value result.Ref = ref return &result, nil - case '.': - - case '/': - - default: - } return nil, 
fmt.Errorf("unsupported: %v, at %v", ref, location) } diff --git a/gateway/router/openapi/openapi3/session_test.go b/gateway/router/openapi/openapi3/session_test.go new file mode 100644 index 000000000..4de4a2672 --- /dev/null +++ b/gateway/router/openapi/openapi3/session_test.go @@ -0,0 +1,152 @@ +package openapi3 + +import ( + "context" + "errors" + "strings" + "testing" +) + +func TestSessionRegisterAndLookup_Table(t *testing.T) { + s := NewSession() + s.Location = "loc" + s.RegisterComponents("loc", &Components{ + Schemas: Schemas{"Pet": {Type: "object"}}, + Parameters: ParametersMap{"id": {Name: "id", In: "query"}}, + Headers: Headers{"/components/headers/Trace": {Name: "Trace", In: "header"}}, + RequestBodies: RequestBodies{"/components/requestBodies/Create": {Description: "create"}}, + Responses: Responses{"/components/responses/Default": {Description: stringRef("default")}}, + SecuritySchemes: SecuritySchemes{"/components/securitySchemes/Bearer": {Type: "http"}}, + Examples: Examples{"/components/examples/Sample": {Summary: "sample"}}, + Links: Links{"/components/links/Self": {OperationID: "self"}}, + Callbacks: Callbacks{"/components/callbacks/Event": {Ref: "eventRef"}}, + }) + + tests := []struct { + name string + lookup func() (interface{}, error) + wantErr string + assert func(t *testing.T, got interface{}) + }{ + {name: "lookup schema", lookup: func() (interface{}, error) { return s.LookupSchema("loc", "#/components/schemas/Pet") }, assert: func(t *testing.T, got interface{}) { + if got.(*Schema).Ref == "" { + t.Fatalf("missing ref") + } + }}, + {name: "lookup parameter", lookup: func() (interface{}, error) { return s.LookupParameter("loc", "#/components/parameters/id") }, assert: func(t *testing.T, got interface{}) { + if got.(*Parameter).Name != "id" { + t.Fatalf("name mismatch") + } + }}, + {name: "lookup header", lookup: func() (interface{}, error) { return s.LookupHeaders("loc", "#/components/headers/Trace") }, assert: func(t *testing.T, got 
interface{}) { + if got.(*Parameter).In != "header" { + t.Fatalf("in mismatch") + } + }}, + {name: "lookup request body", lookup: func() (interface{}, error) { return s.LookupRequestBody("loc", "#/components/requestBodies/Create") }, assert: func(t *testing.T, got interface{}) { + if got.(*RequestBody).Description != "create" { + t.Fatalf("desc mismatch") + } + }}, + {name: "lookup response", lookup: func() (interface{}, error) { return s.LookupResponse("loc", "#/components/responses/Default") }, assert: func(t *testing.T, got interface{}) { + if got.(*Response).Description == nil { + t.Fatalf("desc missing") + } + }}, + {name: "lookup security", lookup: func() (interface{}, error) { + return s.LookupSecurityScheme("loc", "#/components/securitySchemes/Bearer") + }, assert: func(t *testing.T, got interface{}) { + if got.(*SecurityScheme).Type != "http" { + t.Fatalf("type mismatch") + } + }}, + {name: "lookup example", lookup: func() (interface{}, error) { return s.LookupExample("loc", "#/components/examples/Sample") }, assert: func(t *testing.T, got interface{}) { + if got.(*Example).Summary != "sample" { + t.Fatalf("summary mismatch") + } + }}, + {name: "lookup link", lookup: func() (interface{}, error) { return s.LookupLink("loc", "#/components/links/Self") }, assert: func(t *testing.T, got interface{}) { + if got.(*Link).OperationID != "self" { + t.Fatalf("op mismatch") + } + }}, + {name: "lookup callback", lookup: func() (interface{}, error) { return s.LookupCallback("loc", "#/components/callbacks/Event") }, assert: func(t *testing.T, got interface{}) { + if got.(*CallbackRef).Ref != "#/components/callbacks/Event" { + t.Fatalf("ref mismatch") + } + }}, + {name: "missing location", lookup: func() (interface{}, error) { return s.LookupSchema("other", "#/components/schemas/Pet") }, wantErr: "failed to lookup location"}, + {name: "missing value", lookup: func() (interface{}, error) { return s.LookupParameter("loc", "#/components/parameters/other") }, wantErr: 
"failed to lookup"}, + {name: "unsupported ref", lookup: func() (interface{}, error) { return s.LookupSchema("loc", "./components/schemas/Pet") }, wantErr: "unsupported"}, + {name: "unsupported parameter ref", lookup: func() (interface{}, error) { return s.LookupParameter("loc", "./components/parameters/id") }, wantErr: "unsupported"}, + {name: "unsupported header ref", lookup: func() (interface{}, error) { return s.LookupHeaders("loc", "./components/headers/Trace") }, wantErr: "unsupported"}, + {name: "unsupported request body ref", lookup: func() (interface{}, error) { return s.LookupRequestBody("loc", "./components/requestBodies/Create") }, wantErr: "unsupported"}, + {name: "unsupported response ref", lookup: func() (interface{}, error) { return s.LookupResponse("loc", "./components/responses/Default") }, wantErr: "unsupported"}, + {name: "unsupported security ref", lookup: func() (interface{}, error) { + return s.LookupSecurityScheme("loc", "./components/securitySchemes/Bearer") + }, wantErr: "unsupported"}, + {name: "unsupported example ref", lookup: func() (interface{}, error) { return s.LookupExample("loc", "./components/examples/Sample") }, wantErr: "unsupported"}, + {name: "unsupported link ref", lookup: func() (interface{}, error) { return s.LookupLink("loc", "./components/links/Self") }, wantErr: "unsupported"}, + {name: "unsupported callback ref", lookup: func() (interface{}, error) { return s.LookupCallback("loc", "./components/callbacks/Event") }, wantErr: "unsupported"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.lookup() + if tt.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got %v", tt.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.assert != nil { + tt.assert(t, got) + } + }) + } +} + +func TestSessionHelpers(t *testing.T) { + t.Run("add defer and close", func(t *testing.T) { 
+ s := NewSession() + order := 0 + s.AddDefer(func() error { order++; return nil }) + s.AddDefer(func() error { order++; return nil }) + if err := s.Close(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if order != 2 { + t.Fatalf("expected both defers, got %d", order) + } + }) + + t.Run("close returns defer error", func(t *testing.T) { + s := NewSession() + s.AddDefer(func() error { return errors.New("boom") }) + if err := s.Close(); err == nil || err.Error() != "boom" { + t.Fatalf("expected boom, got %v", err) + } + }) + + t.Run("normalize ref", func(t *testing.T) { + s := NewSession() + if got := s.normalizeRef("/components/schemas/Pet", "/components/schemas/"); got != "Pet" { + t.Fatalf("unexpected normalize result: %s", got) + } + }) + + t.Run("lookup session from context", func(t *testing.T) { + ctx := NewSessionContext(context.Background()) + if LookupSession(ctx) == nil { + t.Fatalf("expected session in context") + } + if LookupSession(context.Background()) != nil { + t.Fatalf("expected nil session") + } + }) +} + +func stringRef(v string) *string { return &v } diff --git a/gateway/router/openapi/schema.go b/gateway/router/openapi/schema.go index ef707ec06..3eef32319 100644 --- a/gateway/router/openapi/schema.go +++ b/gateway/router/openapi/schema.go @@ -2,23 +2,18 @@ package openapi import ( "context" - "fmt" "github.com/viant/datly/gateway/router/marshal/config" "github.com/viant/datly/gateway/router/openapi/openapi3" - "github.com/viant/datly/internal/setter" "github.com/viant/datly/repository" "github.com/viant/datly/repository/contract" - "github.com/viant/datly/view" + "github.com/viant/datly/utils/types" "github.com/viant/datly/view/state" "github.com/viant/datly/view/tags" "github.com/viant/tagly/format/text" - ftime "github.com/viant/tagly/format/time" "github.com/viant/xdatly/docs" "github.com/viant/xreflect" "reflect" - "strings" "sync" - "time" ) const ( @@ -53,6 +48,7 @@ type ( schemas []*openapi3.Schema index map[string]int 
generatedSchemas map[string]*openapi3.Schema + visitingTypes map[string]int } ) @@ -95,6 +91,7 @@ func NewContainer() *SchemaContainer { return &SchemaContainer{ index: map[string]int{}, generatedSchemas: map[string]*openapi3.Schema{}, + visitingTypes: map[string]int{}, } } @@ -223,9 +220,12 @@ func (c *ComponentSchema) ReflectSchema(name string, rType reflect.Type, descrip func (c *ComponentSchema) SchemaWithTag(fieldName string, rType reflect.Type, description string, ioConfig *config.IOConfig, tag Tag) *Schema { if parameter := tag.Parameter; parameter != nil { - if parameter.DataType != "" { - typeLookup := c.component.View.Resource().LookupType() - if lType, _ := typeLookup(parameter.DataType); lType != nil { + if tag.IsInput && parameter.DataType != "" { + var typeLookup xreflect.LookupType + if c.component != nil && c.component.View != nil && c.component.View.Resource() != nil { + typeLookup = c.component.View.Resource().LookupType() + } + if lType, _ := types.LookupType(typeLookup, parameter.DataType); lType != nil { rType = lType } } @@ -245,6 +245,7 @@ func (c *ComponentSchema) SchemaWithTag(fieldName string, rType reflect.Type, de docs: c.component.Docs(), } } + func (c *ComponentSchema) GenerateSchema(ctx context.Context, schema *Schema) (*openapi3.Schema, error) { description, err := c.Description(ctx, schema.path, "") if err != nil { @@ -273,294 +274,3 @@ func (c *ComponentSchema) GenerateSchema(ctx context.Context, schema *Schema) (* return result, nil } - -// TODO refactor -func (c *SchemaContainer) addToSchema(ctx context.Context, component *ComponentSchema, dst *openapi3.Schema, schema *Schema) error { - rType := schema.rType - for rType.Kind() == reflect.Ptr { - rType = rType.Elem() - } - - if schema.tag.Example != "" { - dst.Example = schema.tag.Example - } - - rootTable := "" - - if component.component.View.Mode == view.ModeQuery { - rootTable = component.component.View.Table - } - switch rType.Kind() { - case reflect.Slice, reflect.Array: - 
var err error - dst.Items, err = c.createSchema(ctx, component, schema.SliceItem(rType)) - if err != nil { - return err - } - dst.Type = arrayOutput - case reflect.Struct: - if rType == xreflect.TimeType { - dst.Type = stringOutput - timeLayout := schema.tag._tag.TimeLayout - if timeLayout == "" { - timeLayout = time.RFC3339 - } - - var dateFormat string - if containsAny(timeLayout, "15", "04", "05") { - dateFormat = "date-time" - } else { - dateFormat = "date" - } - - dst.Format = dateFormat - if dst.Example == nil { - dst.Example = time.Now().Format(timeLayout) - } - - dst.Pattern = ftime.TimeLayoutToDateFormat(timeLayout) - break - } - - dst.Properties = openapi3.Schemas{} - dst.Type = objectOutput - numField := rType.NumField() - table := schema.tag.Table - for i := 0; i < numField; i++ { - aField := rType.Field(i) - if aField.PkgPath != "" { - continue - } - aTag, err := ParseTag(aField, aField.Tag, schema.isInput, rootTable) - if err != nil { - return err - } - if aTag.Table == "" { - aTag.Table = table - } - if aTag.Ignore { - continue - } - - if aTag.Column != "" && table == "" { - table = rootTable - aTag.Table = rootTable - } - if table != "" && aTag.Column == "" { - aTag.Column = text.DetectCaseFormat(aField.Name).To(text.CaseFormatUpperUnderscore).Format(aField.Name) - } - - if aTag.Inlined { - dst.AdditionalPropertiesAllowed = setter.BoolPtr(true) - continue - } - fieldSchema, err := schema.Field(aField, aTag) - if err != nil { - return err - } - - if component.component.Output.IsExcluded(fieldSchema.path) { - continue - } - - docs := component.component.Docs() - updatedDocumentation(aTag, docs, fieldSchema) - - if aField.Anonymous { - if err := c.addToSchema(ctx, component, dst, fieldSchema); err != nil { - return err - } - continue - } - - if len(dst.Properties) == 0 { - dst.Properties = make(openapi3.Schemas) - } - dst.Properties[fieldSchema.fieldName], err = c.createSchema(ctx, component, fieldSchema) - if err != nil { - return err - } - - if 
!aTag.IsNullable { - dst.Required = append(dst.Required, fieldSchema.fieldName) - } - } - default: - if rType.Kind() == reflect.Interface { - dst.Type = objectOutput - break - } - - if rType.Kind() == reflect.Map { - dst.Type = objectOutput - keyType := rType.Key() - valueType := rType.Elem() - valueTypeName := valueType.Name() - vType, format, err := c.toOpenApiType(valueType) - valueSchema := &openapi3.Schema{ - Type: vType, - Format: format, - } - if err != nil { - switch valueType.Kind() { - case reflect.Struct: - case reflect.Slice: - - if vType, format, err = c.toOpenApiType(valueType.Elem()); err != nil { - return err - } - valueTypeName += strings.Title(valueType.Elem().Name()) + "s" - valueSchema.Type = arrayOutput - valueSchema.Items = &openapi3.Schema{ - Type: vType, - Format: format, - } - default: - return err - } - } - dst.Properties = openapi3.Schemas{} - mapType := strings.Title(keyType.Name()) + valueTypeName + "Map" - dst.Properties[mapType] = valueSchema - break - } - - var err error - dst.Type, dst.Format, err = c.toOpenApiType(rType) - if err != nil { - return err - } - } - - return nil -} - -func updatedDocumentation(aTag *Tag, docs *state.Docs, fieldSchema *Schema) { - if docs == nil { - return - } - if aTag.Column != "" && len(docs.Columns) > 0 { - columns := docs.Columns - if aTag.Description == "" { - aTag.Description, _ = columns.ColumnDescription(aTag.Table, aTag.Column) - } - if aTag.Description == "" { - aTag.Description, _ = columns.ColumnDescription("", aTag.Column) - } - if aTag.Example == "" { - aTag.Example, _ = columns.ColumnExample(aTag.Table, aTag.Column) - } - } - if aTag.Description == "" && len(docs.Paths) > 0 { - if desc, ok := docs.Paths.ByName(fieldSchema.path); ok { - aTag.Description = desc - } else if desc, ok := docs.Paths.ByName(fieldSchema.name); ok { - aTag.Description = desc - fieldSchema.description = desc - } - } - if aTag.Description != "" { - fieldSchema.description = aTag.Description - } - if aTag.Example != 
"" { - fieldSchema.example = aTag.Example - } - -} - -func containsAny(format string, values ...string) bool { - for _, value := range values { - if strings.Contains(format, value) { - return true - } - } - - return false -} - -func (c *ComponentSchema) GetOrGenerateSchema(ctx context.Context, schema *Schema) (*openapi3.Schema, error) { - return c.schemas.CreateSchema(ctx, c, schema) -} - -func (c *SchemaContainer) CreateSchema(ctx context.Context, componentSchema *ComponentSchema, fieldSchema *Schema) (*openapi3.Schema, error) { - c.mux.Lock() - defer c.mux.Unlock() - - return c.createSchema(ctx, componentSchema, fieldSchema) -} - -func (c *SchemaContainer) createSchema(ctx context.Context, componentSchema *ComponentSchema, fieldSchema *Schema) (*openapi3.Schema, error) { - description, err := componentSchema.Description(ctx, fieldSchema.path, fieldSchema.description) - if err != nil { - return nil, err - } - example, err := componentSchema.Example(ctx, fieldSchema.path, fieldSchema.example) - if err != nil { - return nil, err - } - - if fieldSchema.tag.TypeName != "" { - _, ok := c.generatedSchemas[fieldSchema.tag.TypeName] - if ok { - return c.SchemaRef(fieldSchema.tag.TypeName, description), nil - } - } - - apiType, format, ok := c.asOpenApiType(fieldSchema.rType) - if ok { - return &openapi3.Schema{ - Type: apiType, - Format: format, - Description: description, - Example: example, - }, nil - } - - schema, err := componentSchema.GenerateSchema(ctx, fieldSchema) - if err != nil { - return nil, err - } - - if fieldSchema.tag.TypeName != "" { - c.generatedSchemas[fieldSchema.tag.TypeName] = schema - c.schemas = append(c.schemas, schema) - schema = c.SchemaRef(fieldSchema.tag.TypeName, description) - } - - return schema, err -} - -func (c *SchemaContainer) SchemaRef(schemaName string, description string) *openapi3.Schema { - return &openapi3.Schema{ - Ref: "#/components/schemas/" + schemaName, - Description: description, - } -} - -func (c *SchemaContainer) 
toOpenApiType(rType reflect.Type) (string, string, error) { - apiType, format, ok := c.asOpenApiType(rType) - if !ok { - return empty, empty, fmt.Errorf("unsupported openapi3 type %v", rType.String()) - } - return apiType, format, nil -} - -func (c *SchemaContainer) asOpenApiType(rType reflect.Type) (string, string, bool) { - if rType.Kind() == reflect.Ptr { - rType = rType.Elem() - } - switch rType.Kind() { - case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64: - return integerOutput, int64Format, true - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return integerOutput, int32Format, true - case reflect.Float64, reflect.Float32: - return numberOutput, doubleFormat, true - case reflect.Bool: - return booleanOutput, empty, true - case reflect.String: - return stringOutput, empty, true - } - - return empty, empty, false -} diff --git a/gateway/router/openapi/schema_build.go b/gateway/router/openapi/schema_build.go new file mode 100644 index 000000000..e51b27450 --- /dev/null +++ b/gateway/router/openapi/schema_build.go @@ -0,0 +1,608 @@ +package openapi + +import ( + "context" + "fmt" + "github.com/viant/datly/internal/setter" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" + ftime "github.com/viant/tagly/format/time" + "github.com/viant/xreflect" + "os" + "reflect" + "sort" + "strings" + "time" + + "github.com/viant/datly/gateway/router/openapi/openapi3" +) + +func (c *SchemaContainer) addToSchema(ctx context.Context, component *ComponentSchema, dst *openapi3.Schema, schema *Schema) error { + rType := dereferenceType(schema.rType) + applySchemaExample(dst, schema) + + switch rType.Kind() { + case reflect.Slice, reflect.Array: + return c.addArraySchema(ctx, component, dst, schema, rType) + case reflect.Struct: + return c.addStructSchema(ctx, component, dst, schema, rType) + default: + return c.addDefaultSchema(ctx, component, dst, schema, rType) 
+ } +} + +func recursionTypeKey(rType reflect.Type) string { + rType = dereferenceType(rType) + if rType == nil { + return "" + } + switch rType.Kind() { + case reflect.Struct, reflect.Interface, reflect.Slice, reflect.Array, reflect.Map: + return rType.PkgPath() + ":" + rType.String() + default: + return "" + } +} + +func (c *SchemaContainer) addArraySchema(ctx context.Context, component *ComponentSchema, dst *openapi3.Schema, schema *Schema, rType reflect.Type) error { + itemSchema, err := c.createSchema(ctx, component, schema.SliceItem(rType)) + if err != nil { + return err + } + dst.Type = arrayOutput + dst.Items = itemSchema + return nil +} + +func (c *SchemaContainer) addStructSchema(ctx context.Context, component *ComponentSchema, dst *openapi3.Schema, schema *Schema, rType reflect.Type) error { + if rType == xreflect.TimeType { + addTimeSchema(dst, schema) + return nil + } + if selfKey := recursionTypeKey(rType); selfKey != "" { + c.visitingTypes[selfKey]++ + defer func() { + c.visitingTypes[selfKey]-- + if c.visitingTypes[selfKey] == 0 { + delete(c.visitingTypes, selfKey) + } + }() + } + + dst.Type = objectOutput + dst.Properties = openapi3.Schemas{} + rootTable := rootTable(component) + table := schema.tag.Table + + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if shouldSkipStructField(field) { + continue + } + + aTag, err := ParseTag(field, field.Tag, schema.isInput, rootTable) + if err != nil { + return err + } + if normalizeFieldTag(aTag, field.Name, rootTable, table) { + table = aTag.Table + } + if shouldSkipByTag(component, aTag) { + continue + } + if aTag.Inlined { + dst.AdditionalPropertiesAllowed = setter.BoolPtr(true) + continue + } + + fieldSchema, err := schema.Field(field, aTag) + if err != nil { + return err + } + if component.component.Output.IsExcluded(fieldSchema.path) { + continue + } + + updatedDocumentation(aTag, component.component.Docs(), fieldSchema) + if field.Anonymous { + if childKey := 
recursionTypeKey(fieldSchema.rType); childKey != "" && c.visitingTypes[childKey] > 0 { + // Avoid anonymous self/embed loops while preserving named-schema recursion via createSchema. + continue + } + if err := c.addToSchema(ctx, component, dst, fieldSchema); err != nil { + return err + } + continue + } + + childSchema, err := c.createSchema(ctx, component, fieldSchema) + if err != nil { + return err + } + dst.Properties[fieldSchema.fieldName] = childSchema + if !aTag.IsNullable { + dst.Required = append(dst.Required, fieldSchema.fieldName) + } + } + + return nil +} + +func (c *SchemaContainer) addDefaultSchema(ctx context.Context, component *ComponentSchema, dst *openapi3.Schema, schema *Schema, rType reflect.Type) error { + switch rType.Kind() { + case reflect.Interface: + return c.addInterfaceSchema(ctx, component, dst, schema, rType) + case reflect.Map: + return c.addMapSchema(ctx, component, dst, schema, rType) + default: + apiType, format, err := c.toOpenApiType(rType) + if err != nil { + return err + } + dst.Type = apiType + dst.Format = format + return nil + } +} + +func (c *SchemaContainer) addInterfaceSchema(ctx context.Context, component *ComponentSchema, dst *openapi3.Schema, schema *Schema, interfaceType reflect.Type) error { + dst.Type = objectOutput + variants, skipped, err := c.interfaceVariants(ctx, component, schema, interfaceType) + if err != nil { + return err + } + if len(skipped) > 0 { + if shouldFailOnPolymorphismSkip() { + return fmt.Errorf("failed to resolve polymorphic variants for %s: %s", interfaceType.String(), strings.Join(skipped, ",")) + } + if dst.Extension == nil { + dst.Extension = openapi3.Extension{} + } + dst.Extension["x-datly-polymorphism-skipped"] = skipped + dst.Extension["x-datly-polymorphism-mode"] = "best-effort" + } + if len(variants) > 0 { + dst.OneOf = variants + if discriminator := oneOfDiscriminator(variants); discriminator != nil { + dst.Discriminator = discriminator + c.applyDiscriminatorToVariants(discriminator) + 
} + } + return nil +} + +func (c *SchemaContainer) interfaceVariants(ctx context.Context, component *ComponentSchema, schema *Schema, interfaceType reflect.Type) (openapi3.SchemaList, []string, error) { + if component == nil || component.component == nil { + return nil, nil, nil + } + registry := component.component.TypeRegistry() + if registry == nil { + return nil, nil, nil + } + + packageNames := registry.PackageNames() + sort.Strings(packageNames) + + seenByType := map[string]bool{} + result := make(openapi3.SchemaList, 0) + var skipped []string + for _, packageName := range packageNames { + pkg := registry.Package(packageName) + if pkg == nil { + continue + } + + typeNames := pkg.TypeNames() + sort.Strings(typeNames) + for _, typeName := range typeNames { + candidateType, err := pkg.Lookup(typeName) + if err != nil || candidateType == nil { + continue + } + candidateType = dereferenceType(candidateType) + if !implementsInterface(candidateType, interfaceType) { + continue + } + if candidateType.Kind() == reflect.Interface { + continue + } + + key := candidateType.String() + if seenByType[key] { + continue + } + seenByType[key] = true + + typeLabel := typeName + if typeLabel == "" { + typeLabel = candidateType.String() + } + + variantSchema := &Schema{ + docs: schema.docs, + pkg: schema.pkg, + path: key, + fieldName: typeLabel, + name: typeLabel, + description: schema.description, + example: schema.example, + rType: candidateType, + tag: Tag{}, + ioConfig: schema.ioConfig, + isInput: schema.isInput, + } + variantSchema.tag.TypeName = typeLabel + + builtSchema, err := c.createSchema(ctx, component, variantSchema) + if err != nil { + skipped = append(skipped, typeLabel) + continue + } + if builtSchema.Ref == "" { + skipped = append(skipped, typeLabel) + continue + } + result = append(result, builtSchema) + } + } + return result, skipped, nil +} + +func (c *SchemaContainer) addMapSchema(ctx context.Context, component *ComponentSchema, dst *openapi3.Schema, schema 
*Schema, rType reflect.Type) error {
	valueSchema, err := c.mapValueSchema(ctx, component, schema, rType.Elem())
	if err != nil {
		return err
	}
	dst.Type = objectOutput
	dst.AdditionalProperties = valueSchema
	return nil
}

// mapValueSchema resolves the openapi3 schema for a map's value type:
// primitives map directly to (type, format) pairs, slices/arrays recurse into
// an array schema, and anything else is generated as a nested schema that
// inherits documentation context from parent.
func (c *SchemaContainer) mapValueSchema(ctx context.Context, component *ComponentSchema, parent *Schema, valueType reflect.Type) (*openapi3.Schema, error) {
	valueType = dereferenceType(valueType)
	if apiType, format, ok := c.asOpenApiType(valueType); ok {
		return &openapi3.Schema{Type: apiType, Format: format}, nil
	}

	switch valueType.Kind() {
	case reflect.Slice, reflect.Array:
		itemsSchema, err := c.mapValueSchema(ctx, component, parent, valueType.Elem())
		if err != nil {
			return nil, err
		}
		return &openapi3.Schema{Type: arrayOutput, Items: itemsSchema}, nil
	default:
		// The ".value" path suffix distinguishes the map value in docs lookups.
		valueFieldSchema := &Schema{
			docs:        parent.docs,
			pkg:         parent.pkg,
			path:        parent.path + ".value",
			fieldName:   parent.fieldName,
			name:        parent.name,
			description: parent.description,
			example:     parent.example,
			rType:       valueType,
			tag:         Tag{},
			ioConfig:    parent.ioConfig,
			isInput:     parent.isInput,
		}
		if valueType.Name() != "" {
			valueFieldSchema.tag.TypeName = valueType.Name()
		}
		return c.createSchema(ctx, component, valueFieldSchema)
	}
}

// GetOrGenerateSchema returns the (possibly cached) openapi3 schema for the
// supplied field schema.
func (c *ComponentSchema) GetOrGenerateSchema(ctx context.Context, schema *Schema) (*openapi3.Schema, error) {
	return c.schemas.CreateSchema(ctx, c, schema)
}

// CreateSchema is the concurrency-safe entry point to createSchema.
func (c *SchemaContainer) CreateSchema(ctx context.Context, componentSchema *ComponentSchema, fieldSchema *Schema) (*openapi3.Schema, error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	return c.createSchema(ctx, componentSchema, fieldSchema)
}

// createSchema builds (or reuses) the schema for fieldSchema: already
// generated named types resolve to $ref, primitives are emitted inline, and
// newly generated named types are cached and returned as $ref.
// Callers must hold c.mux.
func (c *SchemaContainer) createSchema(ctx context.Context, componentSchema *ComponentSchema, fieldSchema *Schema) (*openapi3.Schema, error) {
	description, err := componentSchema.Description(ctx, fieldSchema.path, fieldSchema.description)
	if err != nil {
		return nil, err
	}
	example, err := componentSchema.Example(ctx, fieldSchema.path, fieldSchema.example)
	if err != nil {
		return nil, err
	}

	if fieldSchema.tag.TypeName != "" {
		if _, ok := c.generatedSchemas[fieldSchema.tag.TypeName]; ok {
			return c.SchemaRef(fieldSchema.tag.TypeName, description), nil
		}
	}

	if apiType, format, ok := c.asOpenApiType(fieldSchema.rType); ok {
		return &openapi3.Schema{
			Type:        apiType,
			Format:      format,
			Description: description,
			Example:     example,
		}, nil
	}

	// Mark named schemas as in-progress before generation so recursive graphs
	// (for example polymorphic self references) resolve to $ref instead of looping.
	if fieldSchema.tag.TypeName != "" {
		c.generatedSchemas[fieldSchema.tag.TypeName] = nil
	}
	schema, err := componentSchema.GenerateSchema(ctx, fieldSchema)
	if err != nil {
		if fieldSchema.tag.TypeName != "" {
			delete(c.generatedSchemas, fieldSchema.tag.TypeName)
		}
		return nil, err
	}

	if fieldSchema.tag.TypeName != "" {
		c.generatedSchemas[fieldSchema.tag.TypeName] = schema
		c.schemas = append(c.schemas, schema)
		schema = c.SchemaRef(fieldSchema.tag.TypeName, description)
	}

	return schema, nil
}

// SchemaRef returns a $ref pointer into #/components/schemas.
func (c *SchemaContainer) SchemaRef(schemaName string, description string) *openapi3.Schema {
	return &openapi3.Schema{
		Ref:         "#/components/schemas/" + schemaName,
		Description: description,
	}
}

// toOpenApiType maps rType to an openapi3 (type, format) pair, or errors when
// the underlying kind is not a supported primitive.
func (c *SchemaContainer) toOpenApiType(rType reflect.Type) (string, string, error) {
	apiType, format, ok := c.asOpenApiType(rType)
	if !ok {
		return empty, empty, fmt.Errorf("unsupported openapi3 type %v", rType.String())
	}
	return apiType, format, nil
}

// asOpenApiType maps primitive Go kinds (after pointer dereferencing) to an
// openapi3 (type, format); ok is false for non-primitive kinds such as
// struct, map, slice and interface.
func (c *SchemaContainer) asOpenApiType(rType reflect.Type) (string, string, bool) {
	rType = dereferenceType(rType)
	switch rType.Kind() {
	case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64:
		return integerOutput, int64Format, true
	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32:
		return integerOutput, int32Format, true
	case reflect.Float64:
		return numberOutput, doubleFormat, true
	case reflect.Float32:
		// Per the OpenAPI 3 data-type formats, 32-bit floats use "float";
		// previously float32 was incorrectly reported as "double".
		return numberOutput, "float", true
	case reflect.Bool:
		return booleanOutput, empty, true
	case reflect.String:
		return stringOutput, empty, true
	}

	return empty, empty, false
}

// updatedDocumentation back-fills the tag's description/example from column
// documentation first, then from path documentation, and mirrors the result
// onto fieldSchema so downstream schema generation sees it.
func updatedDocumentation(aTag *Tag, docs *state.Docs, fieldSchema *Schema) {
	if docs == nil {
		return
	}
	if aTag.Column != "" && len(docs.Columns) > 0 {
		columns := docs.Columns
		if aTag.Description == "" {
			aTag.Description, _ = columns.ColumnDescription(aTag.Table, aTag.Column)
		}
		if aTag.Description == "" {
			// Fall back to a table-agnostic column description.
			aTag.Description, _ = columns.ColumnDescription("", aTag.Column)
		}
		if aTag.Example == "" {
			aTag.Example, _ = columns.ColumnExample(aTag.Table, aTag.Column)
		}
	}
	if aTag.Description == "" && len(docs.Paths) > 0 {
		if desc, ok := docs.Paths.ByName(fieldSchema.path); ok {
			aTag.Description = desc
		} else if desc, ok := docs.Paths.ByName(fieldSchema.name); ok {
			aTag.Description = desc
			fieldSchema.description = desc
		}
	}
	if aTag.Description != "" {
		fieldSchema.description = aTag.Description
	}
	if aTag.Example != "" {
		fieldSchema.example = aTag.Example
	}
}

// containsAny reports whether format contains any of the given substrings.
func containsAny(format string, values ...string) bool {
	for _, value := range values {
		if strings.Contains(format, value) {
			return true
		}
	}
	return false
}

// hasInternalColumnTag reports whether the (table, column) pair is tagged
// internal:"true" in v's column config or, recursively, in any related view.
func hasInternalColumnTag(v *view.View, table, column string) bool {
	if v == nil || column == "" {
		return false
	}
	if matchesViewTable(v, table) {
		if cfg := v.ColumnsConfig[column]; cfg != nil && cfg.Tag != nil && strings.Contains(*cfg.Tag, `internal:"true"`) {
			return true
		}
	}
	for _, rel := range v.With {
		if rel == nil || rel.Of == nil {
			continue
		}
		if hasInternalColumnTag(&rel.Of.View, table, column) {
			return true
		}
	}
	return false
}

// matchesViewTable reports whether table refers to v by table name, alias or
// view name; an empty table matches any view.
func matchesViewTable(v *view.View, table string) bool {
	if table == "" {
		return true
	}
	return strings.EqualFold(v.Table, table) ||
strings.EqualFold(v.Alias, table) ||
		strings.EqualFold(v.Name, table)
}

// rootTable returns the component view's backing table for query-mode views,
// and the empty string for every other mode.
func rootTable(component *ComponentSchema) string {
	aView := component.component.View
	if aView.Mode != view.ModeQuery {
		return ""
	}
	return aView.Table
}

// dereferenceType unwraps pointer types down to their base type.
func dereferenceType(rType reflect.Type) reflect.Type {
	result := rType
	for result.Kind() == reflect.Ptr {
		result = result.Elem()
	}
	return result
}

// applySchemaExample copies a non-empty tag example onto dst.
func applySchemaExample(dst *openapi3.Schema, schema *Schema) {
	if example := schema.tag.Example; example != "" {
		dst.Example = example
	}
}

// addTimeSchema renders a time field as an openapi string whose format is
// "date-time" when the layout carries a clock component and "date" otherwise;
// the layout defaults to RFC3339.
func addTimeSchema(dst *openapi3.Schema, schema *Schema) {
	dst.Type = stringOutput
	layout := schema.tag._tag.TimeLayout
	if layout == "" {
		layout = time.RFC3339
	}
	// Hour ("15"), minute ("04") or second ("05") reference time components.
	if containsAny(layout, "15", "04", "05") {
		dst.Format = "date-time"
	} else {
		dst.Format = "date"
	}
	if dst.Example == nil {
		dst.Example = time.Now().Format(layout)
	}
	dst.Pattern = ftime.TimeLayoutToDateFormat(layout)
}

// shouldSkipStructField reports whether a struct field is unexported or
// explicitly excluded via internal:"true" or json:"-".
func shouldSkipStructField(field reflect.StructField) bool {
	if field.PkgPath != "" { // unexported
		return true
	}
	tagText := string(field.Tag)
	if strings.Contains(tagText, `internal:"true"`) {
		return true
	}
	return strings.Contains(tagText, `json:"-"`)
}

// normalizeFieldTag defaults the tag's table/column: a column declared outside
// any current table is attached to the root table (reported via updatedTable),
// while a table without a column derives the column from the field name.
func normalizeFieldTag(aTag *Tag, fieldName, rootTable, currentTable string) (updatedTable bool) {
	if aTag.Table == "" {
		aTag.Table = currentTable
	}
	if aTag.Ignore {
		return false
	}
	if aTag.Column != "" && currentTable == "" {
		aTag.Table = rootTable
		return true
	}
	if aTag.Column == "" && currentTable != "" {
		aTag.Column = text.DetectCaseFormat(fieldName).To(text.CaseFormatUpperUnderscore).Format(fieldName)
	}
	return false
}

// shouldSkipByTag reports whether the tagged field should be excluded from the
// schema, either explicitly (Ignore) or because its column is marked internal
// on the component's view (with or without a table qualifier).
func shouldSkipByTag(component *ComponentSchema, aTag *Tag) bool {
	if aTag.Ignore {
		return true
	}
	if hasInternalColumnTag(component.component.View, aTag.Table, aTag.Column) {
		return true
	}
	return hasInternalColumnTag(component.component.View, "", aTag.Column)
}

// implementsInterface reports whether candidateType (or its pointer form, for
// pointer-receiver method sets) implements interfaceType.
func implementsInterface(candidateType, interfaceType reflect.Type) bool {
	switch {
	case candidateType.Implements(interfaceType):
		return true
	case candidateType.Kind() != reflect.Ptr && reflect.PtrTo(candidateType).Implements(interfaceType):
		return true
	default:
		return false
	}
}

// oneOfDiscriminator derives a "type" discriminator whose mapping points each
// referenced variant's schema name at its $ref; it returns nil when no variant
// carries a ref.
func oneOfDiscriminator(variants openapi3.SchemaList) *openapi3.Discriminator {
	mapping := map[string]string{}
	for _, variant := range variants {
		if variant == nil || variant.Ref == "" {
			continue
		}
		ref := variant.Ref
		name := ref[strings.LastIndex(ref, "/")+1:]
		if name == "" {
			continue
		}
		mapping[name] = ref
	}
	if len(mapping) == 0 {
		return nil
	}
	return &openapi3.Discriminator{
		PropertyName: "type",
		Mapping:      mapping,
	}
}

// applyDiscriminatorToVariants injects the discriminator property (a required
// single-value string enum) into every mapped object variant that does not
// already declare it; non-object and unknown variants are left untouched.
func (c *SchemaContainer) applyDiscriminatorToVariants(discriminator *openapi3.Discriminator) {
	if discriminator == nil || len(discriminator.Mapping) == 0 {
		return
	}
	property := discriminator.PropertyName
	for value, ref := range discriminator.Mapping {
		schemaName := refName(ref)
		if schemaName == "" {
			continue
		}
		variant := c.generatedSchemas[schemaName]
		if variant == nil || variant.Type != objectOutput {
			continue
		}
		if len(variant.Properties) == 0 {
			variant.Properties = openapi3.Schemas{}
		}
		if variant.Properties[property] == nil {
			variant.Properties[property] = &openapi3.Schema{
				Type: stringOutput,
				Enum: []interface{}{value},
			}
		}
		if !containsString(variant.Required, property) {
			variant.Required = append(variant.Required, property)
		}
	}
}

// refName extracts the trailing schema name from a $ref path; empty or
// malformed refs yield the empty string.
func refName(ref string) string {
	index := strings.LastIndex(ref, "/")
	if ref == "" || index == -1 || index == len(ref)-1 {
		return ""
	}
	return ref[index+1:]
}

// containsString reports whether target occurs in values.
func containsString(values []string, target string) bool {
	for _, item := range values {
		if item == target {
			return true
		}
	}
	return false
}

// shouldFailOnPolymorphismSkip reports whether DATLY_OPENAPI_POLY_STRICT is
// set to a truthy value, which turns skipped implementors into hard errors.
func shouldFailOnPolymorphismSkip() bool {
	raw := strings.TrimSpace(strings.ToLower(os.Getenv("DATLY_OPENAPI_POLY_STRICT")))
	return raw == "1" || raw
== "true" || raw == "yes" +} diff --git a/gateway/router/openapi/schema_build_helpers_test.go b/gateway/router/openapi/schema_build_helpers_test.go new file mode 100644 index 000000000..0fc1516ef --- /dev/null +++ b/gateway/router/openapi/schema_build_helpers_test.go @@ -0,0 +1,409 @@ +package openapi + +import ( + "context" + "reflect" + "testing" + + "github.com/viant/datly/gateway/router/openapi/openapi3" + "github.com/viant/datly/repository" + "github.com/viant/datly/view" + "github.com/viant/xreflect" +) + +type testAnimal interface { + Kind() string +} + +type testDog struct{} + +func (testDog) Kind() string { return "dog" } + +type testCat struct{} + +func (*testCat) Kind() string { return "cat" } + +type testTree struct{} + +type testUnsupported chan int + +func (testUnsupported) Kind() string { return "unsupported" } + +type recursiveAnimal interface { + Kind() string +} + +type recursiveDog struct { + Child recursiveAnimal `json:"child,omitempty"` +} + +func (recursiveDog) Kind() string { return "dog" } + +type RecursiveEmbed struct { + *RecursiveEmbed +} + +func TestSchemaBuildHelpers_Table(t *testing.T) { + t.Run("apply schema example", func(t *testing.T) { + dst := &openapi3.Schema{} + applySchemaExample(dst, &Schema{tag: Tag{Example: "abc"}}) + if dst.Example != "abc" { + t.Fatalf("expected example to be applied") + } + applySchemaExample(dst, &Schema{}) + if dst.Example != "abc" { + t.Fatalf("expected empty example not to override existing value") + } + }) + + t.Run("root table", func(t *testing.T) { + queryComp := &ComponentSchema{component: &repository.Component{View: &view.View{Mode: view.ModeQuery, Table: "users"}}} + if got := rootTable(queryComp); got != "users" { + t.Fatalf("expected users, got %q", got) + } + nonQueryComp := &ComponentSchema{component: &repository.Component{View: &view.View{Mode: view.Mode("Other"), Table: "users"}}} + if got := rootTable(nonQueryComp); got != "" { + t.Fatalf("expected empty root table for non-query mode") + 
} + }) + + t.Run("normalize field tag", func(t *testing.T) { + tests := []struct { + name string + tag Tag + rootTable string + table string + wantTable string + updated bool + column string + }{ + {name: "column sets root table", tag: Tag{Column: "ID"}, rootTable: "users", table: "", wantTable: "users", updated: true, column: "ID"}, + {name: "table infers column", tag: Tag{}, rootTable: "", table: "users", wantTable: "users", updated: false, column: "FIRST_NAME"}, + {name: "ignored tag", tag: Tag{Ignore: true}, rootTable: "users", table: "users", wantTable: "users", updated: false, column: ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tag := tt.tag + updated := normalizeFieldTag(&tag, "FirstName", tt.rootTable, tt.table) + if updated != tt.updated { + t.Fatalf("expected updated=%v, got %v", tt.updated, updated) + } + if tag.Table != tt.wantTable { + t.Fatalf("expected table %q, got %q", tt.wantTable, tag.Table) + } + if tag.Column != tt.column { + t.Fatalf("expected column %q, got %q", tt.column, tag.Column) + } + }) + } + }) + + t.Run("should skip by tag", func(t *testing.T) { + tag := `internal:"true"` + component := &ComponentSchema{component: &repository.Component{View: &view.View{Table: "users", ColumnsConfig: map[string]*view.ColumnConfig{"ID": {Tag: &tag}}}}} + if !shouldSkipByTag(component, &Tag{Ignore: true}) { + t.Fatalf("expected ignored tag to be skipped") + } + if !shouldSkipByTag(component, &Tag{Table: "users", Column: "ID"}) { + t.Fatalf("expected internal column to be skipped") + } + if shouldSkipByTag(component, &Tag{Ignore: false, Table: "users", Column: "Name"}) { + t.Fatalf("did not expect non-internal column to be skipped") + } + }) + + t.Run("should skip struct field", func(t *testing.T) { + type sample struct { + exported string + Visible string + Hidden string `json:"-"` + Internal string `internal:"true"` + } + rType := reflect.TypeOf(sample{}) + if !shouldSkipStructField(rType.Field(0)) { + 
t.Fatalf("expected unexported field to be skipped") + } + if shouldSkipStructField(rType.Field(1)) { + t.Fatalf("did not expect visible field to be skipped") + } + if !shouldSkipStructField(rType.Field(2)) { + t.Fatalf("expected json:- field to be skipped") + } + if !shouldSkipStructField(rType.Field(3)) { + t.Fatalf("expected internal:true field to be skipped") + } + }) + + t.Run("add time schema default and pre-existing example", func(t *testing.T) { + dst := &openapi3.Schema{} + addTimeSchema(dst, &Schema{}) + if dst.Type != stringOutput || dst.Format != "date-time" || dst.Pattern == "" { + t.Fatalf("unexpected default time schema: type=%s format=%s pattern=%s", dst.Type, dst.Format, dst.Pattern) + } + existing := &openapi3.Schema{Example: "preset"} + addTimeSchema(existing, &Schema{}) + if existing.Example != "preset" { + t.Fatalf("expected existing example to be preserved") + } + }) + + t.Run("interface oneOf scaffolding", func(t *testing.T) { + t.Setenv("DATLY_OPENAPI_POLY_STRICT", "false") + component := newTestComponent(t) + types := xreflect.NewTypes() + if err := types.Register("Animal", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf((*testAnimal)(nil)).Elem())); err != nil { + t.Fatalf("register interface failed: %v", err) + } + if err := types.Register("Dog", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(testDog{}))); err != nil { + t.Fatalf("register dog failed: %v", err) + } + if err := types.Register("Cat", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(testCat{}))); err != nil { + t.Fatalf("register cat failed: %v", err) + } + if err := types.Register("DogAlias", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(testDog{}))); err != nil { + t.Fatalf("register dog alias failed: %v", err) + } + if err := types.Register("Tree", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(testTree{}))); err != nil { + t.Fatalf("register tree failed: %v", 
err) + } + if err := types.Register("Unsupported", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(testUnsupported(nil)))); err != nil { + t.Fatalf("register unsupported failed: %v", err) + } + setUnexportedField(component, "types", types) + + container := NewContainer() + componentSchema := &ComponentSchema{component: component, schemas: container} + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), componentSchema, dst, &Schema{ + rType: reflect.TypeOf((*testAnimal)(nil)).Elem(), + ioConfig: component.IOConfig(), + }) + if err != nil { + t.Fatalf("unexpected addToSchema error: %v", err) + } + if dst.Type != objectOutput { + t.Fatalf("expected object type for interface, got %q", dst.Type) + } + if len(dst.OneOf) != 2 { + t.Fatalf("expected oneOf variants for interface, got %d", len(dst.OneOf)) + } + if dst.Discriminator == nil { + t.Fatalf("expected discriminator to be set for oneOf interface schema") + } + if dst.Discriminator.PropertyName != "type" { + t.Fatalf("expected discriminator propertyName type, got %q", dst.Discriminator.PropertyName) + } + if len(dst.Discriminator.Mapping) != 2 { + t.Fatalf("expected discriminator mapping entries, got %d", len(dst.Discriminator.Mapping)) + } + if dst.Discriminator.Mapping["Dog"] != "#/components/schemas/Dog" { + t.Fatalf("unexpected discriminator mapping for Dog: %q", dst.Discriminator.Mapping["Dog"]) + } + if dst.Discriminator.Mapping["Cat"] != "#/components/schemas/Cat" { + t.Fatalf("unexpected discriminator mapping for Cat: %q", dst.Discriminator.Mapping["Cat"]) + } + + dogSchema := container.generatedSchemas["Dog"] + if dogSchema == nil || dogSchema.Properties["type"] == nil { + t.Fatalf("expected discriminator property injected in Dog schema") + } + if !containsString(dogSchema.Required, "type") { + t.Fatalf("expected discriminator property required in Dog schema") + } + + if dst.Extension == nil { + t.Fatalf("expected best-effort extension metadata") + } + 
skipped, ok := dst.Extension["x-datly-polymorphism-skipped"].([]string) + if !ok || len(skipped) == 0 { + t.Fatalf("expected skipped implementors extension") + } + }) + + t.Run("interface oneOf fallback without registry", func(t *testing.T) { + component := &repository.Component{} + container := NewContainer() + componentSchema := &ComponentSchema{component: component, schemas: container} + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), componentSchema, dst, &Schema{ + rType: reflect.TypeOf((*testAnimal)(nil)).Elem(), + ioConfig: component.IOConfig(), + }) + if err != nil { + t.Fatalf("unexpected addToSchema error: %v", err) + } + if dst.Type != objectOutput { + t.Fatalf("expected object type for interface fallback, got %q", dst.Type) + } + if len(dst.OneOf) != 0 { + t.Fatalf("expected no oneOf variants when registry is unavailable, got %d", len(dst.OneOf)) + } + if dst.Discriminator != nil { + t.Fatalf("expected no discriminator without variants") + } + }) + + t.Run("implements interface", func(t *testing.T) { + tests := []struct { + name string + candidate reflect.Type + want bool + }{ + {name: "value receiver", candidate: reflect.TypeOf(testDog{}), want: true}, + {name: "pointer receiver", candidate: reflect.TypeOf(testCat{}), want: true}, + {name: "not implementor", candidate: reflect.TypeOf(struct{}{}), want: false}, + } + iface := reflect.TypeOf((*testAnimal)(nil)).Elem() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := implementsInterface(tt.candidate, iface); got != tt.want { + t.Fatalf("expected %v, got %v", tt.want, got) + } + }) + } + }) + + t.Run("interface variants nil component", func(t *testing.T) { + container := NewContainer() + variants, skipped, err := container.interfaceVariants(context.Background(), nil, &Schema{}, reflect.TypeOf((*testAnimal)(nil)).Elem()) + if err != nil { + t.Fatalf("unexpected interfaceVariants error: %v", err) + } + if len(variants) != 0 { + t.Fatalf("expected 
no variants for nil component, got %d", len(variants)) + } + if len(skipped) != 0 { + t.Fatalf("expected no skipped variants for nil component") + } + }) + + t.Run("interface oneOf strict mode", func(t *testing.T) { + t.Setenv("DATLY_OPENAPI_POLY_STRICT", "true") + component := newTestComponent(t) + types := xreflect.NewTypes() + if err := types.Register("Animal", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf((*testAnimal)(nil)).Elem())); err != nil { + t.Fatalf("register interface failed: %v", err) + } + if err := types.Register("Dog", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(testDog{}))); err != nil { + t.Fatalf("register dog failed: %v", err) + } + if err := types.Register("Unsupported", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(testUnsupported(nil)))); err != nil { + t.Fatalf("register unsupported failed: %v", err) + } + setUnexportedField(component, "types", types) + + container := NewContainer() + componentSchema := &ComponentSchema{component: component, schemas: container} + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), componentSchema, dst, &Schema{ + rType: reflect.TypeOf((*testAnimal)(nil)).Elem(), + ioConfig: component.IOConfig(), + }) + if err == nil { + t.Fatalf("expected strict mode polymorphism error") + } + }) + + t.Run("oneOf discriminator and helper branches", func(t *testing.T) { + t.Run("empty refs yield nil discriminator", func(t *testing.T) { + discriminator := oneOfDiscriminator(openapi3.SchemaList{{Type: objectOutput}, nil}) + if discriminator != nil { + t.Fatalf("expected nil discriminator when refs are absent") + } + }) + + t.Run("apply discriminator skips non-object and missing schema", func(t *testing.T) { + container := NewContainer() + container.generatedSchemas["User"] = &openapi3.Schema{Type: objectOutput} + container.generatedSchemas["Arr"] = &openapi3.Schema{Type: arrayOutput} + 
container.applyDiscriminatorToVariants(&openapi3.Discriminator{ + PropertyName: "kind", + Mapping: map[string]string{ + "user": "#/components/schemas/User", + "arr": "#/components/schemas/Arr", + "miss": "#/components/schemas/Missing", + }, + }) + user := container.generatedSchemas["User"] + if user == nil || user.Properties["kind"] == nil { + t.Fatalf("expected discriminator property on object variant") + } + if !containsString(user.Required, "kind") { + t.Fatalf("expected discriminator property required on object variant") + } + arr := container.generatedSchemas["Arr"] + if arr != nil && arr.Properties != nil { + if _, ok := arr.Properties["kind"]; ok { + t.Fatalf("did not expect discriminator property on non-object variant") + } + } + }) + + t.Run("refName invalid variants", func(t *testing.T) { + if got := refName(""); got != "" { + t.Fatalf("expected empty ref name for empty ref") + } + if got := refName("abc"); got != "" { + t.Fatalf("expected empty ref name for malformed ref") + } + if got := refName("abc/"); got != "" { + t.Fatalf("expected empty ref name for trailing slash") + } + }) + }) + + t.Run("recursive polymorphic graph does not loop", func(t *testing.T) { + t.Setenv("DATLY_OPENAPI_POLY_STRICT", "false") + component := newTestComponent(t) + types := xreflect.NewTypes() + if err := types.Register("RecursiveAnimal", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf((*recursiveAnimal)(nil)).Elem())); err != nil { + t.Fatalf("register interface failed: %v", err) + } + if err := types.Register("RecursiveDog", xreflect.WithPackage("test"), xreflect.WithReflectType(reflect.TypeOf(recursiveDog{}))); err != nil { + t.Fatalf("register struct failed: %v", err) + } + setUnexportedField(component, "types", types) + + container := NewContainer() + componentSchema := &ComponentSchema{component: component, schemas: container} + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), componentSchema, dst, &Schema{ + rType: 
reflect.TypeOf((*recursiveAnimal)(nil)).Elem(), + ioConfig: component.IOConfig(), + }) + if err != nil { + t.Fatalf("unexpected addToSchema error: %v", err) + } + if len(dst.OneOf) == 0 { + t.Fatalf("expected oneOf variants for recursive interface") + } + dog := container.generatedSchemas["RecursiveDog"] + if dog == nil { + t.Fatalf("expected RecursiveDog schema to be generated") + } + child := dog.Properties["child"] + if child == nil { + t.Fatalf("expected recursive child schema") + } + if child.Ref == "" && len(child.OneOf) == 0 { + t.Fatalf("expected recursive child to be represented as ref or oneOf") + } + }) + + t.Run("anonymous self embed does not recurse indefinitely", func(t *testing.T) { + component := newTestComponent(t) + container := NewContainer() + componentSchema := &ComponentSchema{component: component, schemas: container} + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), componentSchema, dst, &Schema{ + rType: reflect.TypeOf(RecursiveEmbed{}), + ioConfig: component.IOConfig(), + }) + if err != nil { + t.Fatalf("unexpected addToSchema error: %v", err) + } + if dst.Type != objectOutput { + t.Fatalf("expected object type for recursive embed, got %q", dst.Type) + } + }) +} diff --git a/gateway/router/openapi/schema_helpers_test.go b/gateway/router/openapi/schema_helpers_test.go new file mode 100644 index 000000000..ddde60523 --- /dev/null +++ b/gateway/router/openapi/schema_helpers_test.go @@ -0,0 +1,329 @@ +package openapi + +import ( + "context" + "reflect" + "testing" + + "github.com/viant/datly/gateway/router/openapi/openapi3" + "github.com/viant/datly/repository" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type sampleNested struct { + ID int +} + +type sampleWithField struct { + UserName string `json:"user_name" desc:"user name desc" example:"bob"` +} + +func TestSchemaSliceItem(t *testing.T) { + tests := []struct { + name string + typeName string + rType reflect.Type + expectType 
reflect.Type + expectSchema string + }{ + {name: "named element", typeName: "Entry", rType: reflect.TypeOf([]sampleNested{}), expectType: reflect.TypeOf(sampleNested{}), expectSchema: "sampleNested"}, + {name: "anonymous element", typeName: "Entry", rType: reflect.TypeOf([]struct{ Value int }{}), expectType: reflect.TypeOf(struct{ Value int }{}), expectSchema: "EntryItem"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Schema{tag: Tag{TypeName: tt.typeName}} + item := s.SliceItem(tt.rType) + if item.rType != tt.expectType { + t.Fatalf("expected %v, got %v", tt.expectType, item.rType) + } + if item.tag.TypeName != tt.expectSchema { + t.Fatalf("expected %q, got %q", tt.expectSchema, item.tag.TypeName) + } + }) + } +} + +func TestSchemaField(t *testing.T) { + component := &repository.Component{} + rType := reflect.TypeOf(sampleWithField{}) + field := rType.Field(0) + + tests := []struct { + name string + tag *Tag + expectField string + expectDesc string + expectExample string + }{ + {name: "uses json name", tag: &Tag{JSONName: "custom_name"}, expectField: "custom_name", expectDesc: "user name desc", expectExample: "bob"}, + {name: "falls back to formatted name", tag: &Tag{}, expectField: "UserName", expectDesc: "user name desc", expectExample: "bob"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Schema{ioConfig: component.IOConfig()} + got, err := s.Field(field, tt.tag) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.fieldName != tt.expectField { + t.Fatalf("expected field %q, got %q", tt.expectField, got.fieldName) + } + if got.description != tt.expectDesc { + t.Fatalf("expected description %q, got %q", tt.expectDesc, got.description) + } + if got.example != tt.expectExample { + t.Fatalf("expected example %q, got %q", tt.expectExample, got.example) + } + }) + } +} + +func TestContainsAny(t *testing.T) { + tests := []struct { + name string + format string + values []string + 
expect bool + }{ + {name: "contains", format: "2006-01-02T15:04:05", values: []string{"15", "04"}, expect: true}, + {name: "not contains", format: "2006-01-02", values: []string{"15", "04", "05"}, expect: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := containsAny(tt.format, tt.values...) + if got != tt.expect { + t.Fatalf("expected %v, got %v", tt.expect, got) + } + }) + } +} + +func TestAsOpenAPIType(t *testing.T) { + container := NewContainer() + tests := []struct { + name string + rType reflect.Type + api string + format string + ok bool + }{ + {name: "int64", rType: reflect.TypeOf(int64(1)), api: integerOutput, format: int64Format, ok: true}, + {name: "uint32", rType: reflect.TypeOf(uint32(1)), api: integerOutput, format: int32Format, ok: true}, + {name: "float64", rType: reflect.TypeOf(float64(1)), api: numberOutput, format: doubleFormat, ok: true}, + {name: "bool", rType: reflect.TypeOf(true), api: booleanOutput, format: empty, ok: true}, + {name: "string", rType: reflect.TypeOf(""), api: stringOutput, format: empty, ok: true}, + {name: "ptr", rType: reflect.TypeOf(new(int)), api: integerOutput, format: int64Format, ok: true}, + {name: "unsupported struct", rType: reflect.TypeOf(struct{}{}), api: empty, format: empty, ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api, format, ok := container.asOpenApiType(tt.rType) + if ok != tt.ok { + t.Fatalf("expected ok=%v, got %v", tt.ok, ok) + } + if api != tt.api || format != tt.format { + t.Fatalf("expected %s/%s, got %s/%s", tt.api, tt.format, api, format) + } + }) + } +} + +func TestToOpenApiType(t *testing.T) { + container := NewContainer() + tests := []struct { + name string + rType reflect.Type + wantError bool + }{ + {name: "supported", rType: reflect.TypeOf(int(1)), wantError: false}, + {name: "unsupported", rType: reflect.TypeOf(struct{}{}), wantError: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + _, _, err := container.toOpenApiType(tt.rType) + if (err != nil) != tt.wantError { + t.Fatalf("wantError=%v got err=%v", tt.wantError, err) + } + }) + } +} + +func TestSchemaRef(t *testing.T) { + container := NewContainer() + tests := []struct { + name string + schemaName string + description string + expectRef string + }{ + {name: "basic", schemaName: "MyType", description: "desc", expectRef: "#/components/schemas/MyType"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := container.SchemaRef(tt.schemaName, tt.description) + if got.Ref != tt.expectRef { + t.Fatalf("expected ref %q, got %q", tt.expectRef, got.Ref) + } + if got.Description != tt.description { + t.Fatalf("expected description %q, got %q", tt.description, got.Description) + } + }) + } +} + +func TestUpdatedDocumentation(t *testing.T) { + tests := []struct { + name string + tag *Tag + docs *state.Docs + field *Schema + expectDesc string + expectExample string + }{ + { + name: "column docs", + tag: &Tag{Table: "users", Column: "name"}, + docs: &state.Docs{Columns: state.Documentation{"users.name": "column desc", "users.name$example": "alice"}}, + field: &Schema{path: "pkg.User.Name", name: "Name"}, + expectDesc: "column desc", expectExample: "alice", + }, + { + name: "path docs fallback", + tag: &Tag{}, + docs: &state.Docs{Paths: state.Documentation{"pkg.User.Name": "path desc"}}, + field: &Schema{path: "pkg.User.Name", name: "Name"}, + expectDesc: "path desc", expectExample: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + updatedDocumentation(tt.tag, tt.docs, tt.field) + if tt.field.description != tt.expectDesc { + t.Fatalf("expected description %q, got %q", tt.expectDesc, tt.field.description) + } + if tt.field.example != tt.expectExample { + t.Fatalf("expected example %q, got %q", tt.expectExample, tt.field.example) + } + }) + } +} + +func TestMatchesViewTable(t *testing.T) { + v := &view.View{Table: "users", Alias: 
"u", Name: "UsersView"} + tests := []struct { + name string + table string + expect bool + }{ + {name: "table", table: "users", expect: true}, + {name: "alias", table: "u", expect: true}, + {name: "name", table: "UsersView", expect: true}, + {name: "empty", table: "", expect: true}, + {name: "miss", table: "products", expect: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := matchesViewTable(v, tt.table); got != tt.expect { + t.Fatalf("expected %v, got %v", tt.expect, got) + } + }) + } +} + +func TestHasInternalColumnTag(t *testing.T) { + tag := `internal:"true"` + relTag := `internal:"true"` + v := &view.View{ + Table: "users", + ColumnsConfig: map[string]*view.ColumnConfig{ + "ID": {Tag: &tag}, + }, + With: []*view.Relation{ + {Of: &view.ReferenceView{View: view.View{Table: "orders", ColumnsConfig: map[string]*view.ColumnConfig{"OrderID": {Tag: &relTag}}}}}, + }, + } + + tests := []struct { + name string + view *view.View + table string + column string + expect bool + }{ + {name: "nil view", view: nil, table: "users", column: "ID", expect: false}, + {name: "empty column", view: v, table: "users", column: "", expect: false}, + {name: "current view", view: v, table: "users", column: "ID", expect: true}, + {name: "relation", view: v, table: "orders", column: "OrderID", expect: true}, + {name: "not internal", view: v, table: "users", column: "Name", expect: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := hasInternalColumnTag(tt.view, tt.table, tt.column); got != tt.expect { + t.Fatalf("expected %v, got %v", tt.expect, got) + } + }) + } +} + +func TestAddToSchemaSimpleBranches(t *testing.T) { + container := NewContainer() + component := &ComponentSchema{component: &repository.Component{View: &view.View{}}, schemas: container} + type mapRecord struct { + ID int `json:"id"` + } + tests := []struct { + name string + rType reflect.Type + expectType string + expectPropsLen int + 
expectAdditionalType string + expectAdditionalItemRef bool + }{ + {name: "interface", rType: reflect.TypeOf((*interface{})(nil)).Elem(), expectType: objectOutput, expectPropsLen: 0}, + {name: "map primitive value", rType: reflect.TypeOf(map[string]int{}), expectType: objectOutput, expectPropsLen: 0, expectAdditionalType: integerOutput}, + {name: "map array value", rType: reflect.TypeOf(map[string][]string{}), expectType: objectOutput, expectPropsLen: 0, expectAdditionalType: arrayOutput}, + {name: "map object value", rType: reflect.TypeOf(map[string]mapRecord{}), expectType: objectOutput, expectPropsLen: 0, expectAdditionalItemRef: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &openapi3.Schema{} + err := container.addToSchema(context.Background(), component, dst, &Schema{rType: tt.rType, ioConfig: component.component.IOConfig()}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dst.Type != tt.expectType { + t.Fatalf("expected type %q, got %q", tt.expectType, dst.Type) + } + if len(dst.Properties) != tt.expectPropsLen { + t.Fatalf("expected %d properties, got %d", tt.expectPropsLen, len(dst.Properties)) + } + if tt.expectAdditionalType != "" { + if dst.AdditionalProperties == nil { + t.Fatalf("expected additionalProperties schema") + } + if dst.AdditionalProperties.Type != tt.expectAdditionalType { + t.Fatalf("expected additionalProperties type %q, got %q", tt.expectAdditionalType, dst.AdditionalProperties.Type) + } + } + if tt.expectAdditionalItemRef { + if dst.AdditionalProperties == nil { + t.Fatalf("expected additionalProperties schema") + } + if dst.AdditionalProperties.Ref == "" { + t.Fatalf("expected additionalProperties to reference a schema") + } + } + }) + } +} diff --git a/gateway/router/openapi/tag.go b/gateway/router/openapi/tag.go index 7c144f337..9a1f51549 100644 --- a/gateway/router/openapi/tag.go +++ b/gateway/router/openapi/tag.go @@ -35,6 +35,7 @@ type ( _tag format.Tag TypeName string 
Parameter *tags.Parameter + IsInput bool Column string Table string } @@ -78,6 +79,12 @@ func ParseTag(field reflect.StructField, tag reflect.StructTag, isInput bool, ro Example: tag.Get(tags.ExampleTag), JSONName: jsonName, _tag: *aTag, + IsInput: isInput, + } + + // Keep internal runtime-only fields out of OpenAPI schema. + if tag.Get("internal") == "true" { + ret.Ignore = true } if tags, _ := tags.Parse(tag, nil, tags.ParameterTag); tags != nil { diff --git a/gateway/router/openapi/tag_parse_test.go b/gateway/router/openapi/tag_parse_test.go new file mode 100644 index 000000000..218387fe6 --- /dev/null +++ b/gateway/router/openapi/tag_parse_test.go @@ -0,0 +1,73 @@ +package openapi + +import ( + "reflect" + "testing" +) + +type tagParseNamed struct { + A int +} + +type tagParseFixture struct { + Values []int `json:"values"` + Any tagParseNamed `json:"any"` + Hidden string `json:"hidden" internal:"true"` + Summary string `json:"summary" parameter:"kind=output,in=summary"` + ViewOut string `json:"view_out" parameter:"kind=output,in=view"` + InputDrop string `json:"input_drop" parameter:"kind=query,in=id"` + ByViewTable string `json:"by_view" view:"name=V,table=orders"` + BySQLX string `json:"by_sqlx" sqlx:"name=ORD_ID"` +} + +func TestParseTag(t *testing.T) { + rType := reflect.TypeOf(tagParseFixture{}) + tests := []struct { + name string + fieldIndex int + isInput bool + rootTable string + expectIgnore bool + expectTypeName string + expectJSONName string + expectNullable bool + expectTableValue string + }{ + {name: "slice sets json name", fieldIndex: 0, isInput: false, rootTable: "", expectIgnore: false, expectTypeName: "", expectJSONName: "values", expectNullable: false}, + {name: "struct sets type name", fieldIndex: 1, isInput: false, rootTable: "", expectIgnore: false, expectTypeName: "openapi.tagParseNamed", expectJSONName: "any", expectNullable: false}, + {name: "internal flag ignored", fieldIndex: 2, isInput: false, rootTable: "", expectIgnore: true, 
expectTypeName: "", expectJSONName: "hidden", expectNullable: false}, + {name: "output summary table", fieldIndex: 3, isInput: false, rootTable: "root", expectIgnore: false, expectTypeName: "", expectJSONName: "summary", expectNullable: false, expectTableValue: "SUMMARY"}, + {name: "output view table", fieldIndex: 4, isInput: false, rootTable: "root", expectIgnore: false, expectTypeName: "", expectJSONName: "view_out", expectNullable: false, expectTableValue: "root"}, + {name: "input non body ignored", fieldIndex: 5, isInput: true, rootTable: "root", expectIgnore: true, expectTypeName: "", expectJSONName: "input_drop", expectNullable: false}, + {name: "view tag table", fieldIndex: 6, isInput: false, rootTable: "root", expectIgnore: false, expectTypeName: "", expectJSONName: "by_view", expectNullable: false, expectTableValue: "orders"}, + {name: "sqlx column captured", fieldIndex: 7, isInput: false, rootTable: "root", expectIgnore: false, expectTypeName: "", expectJSONName: "by_sqlx", expectNullable: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + field := rType.Field(tt.fieldIndex) + parsed, err := ParseTag(field, field.Tag, tt.isInput, tt.rootTable) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if parsed.Ignore != tt.expectIgnore { + t.Fatalf("expected ignore %v, got %v", tt.expectIgnore, parsed.Ignore) + } + if parsed.TypeName != tt.expectTypeName { + t.Fatalf("expected type name %q, got %q", tt.expectTypeName, parsed.TypeName) + } + if parsed.JSONName != tt.expectJSONName { + t.Fatalf("expected json name %q, got %q", tt.expectJSONName, parsed.JSONName) + } + if parsed.IsNullable != tt.expectNullable { + t.Fatalf("expected nullable %v, got %v", tt.expectNullable, parsed.IsNullable) + } + if parsed.Table != tt.expectTableValue { + t.Fatalf("expected table %q, got %q", tt.expectTableValue, parsed.Table) + } + if tt.name == "sqlx column captured" && parsed.Column != "ORD_ID" { + t.Fatalf("expected column ORD_ID, 
got %q", parsed.Column) + } + }) + } +} diff --git a/gateway/router/status/error.go b/gateway/router/status/error.go index 3d9111abc..41087fe5e 100644 --- a/gateway/router/status/error.go +++ b/gateway/router/status/error.go @@ -1,9 +1,12 @@ package status import ( + "errors" + "net/http" + "github.com/viant/datly/service/executor/expand" + derrors "github.com/viant/datly/utils/errors" "github.com/viant/datly/utils/httputils" - "github.com/viant/datly/utils/types" "github.com/viant/govalidator" svalidator "github.com/viant/sqlx/io/validator" "github.com/viant/xdatly/handler/response" @@ -13,31 +16,103 @@ func NormalizeErr(err error, statusCode int) (int, string, interface{}) { violations := httputils.Violations{} switch actual := err.(type) { case *response.Error: - return actual.StatusCode(), actual.Message, nil + if derrors.IsDatabaseError(actual.Err) || derrors.IsDatabaseError(errors.New(actual.Message)) { + actual.Code = http.StatusInternalServerError + actual.Message = http.StatusText(http.StatusInternalServerError) + return http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), nil + } + code := actual.StatusCode() + if code == 0 { + code = statusCode + } + if code == 0 { + code = http.StatusBadRequest + } + // For explicit 5xx we keep response generic, for 4xx we trust the configured message. 
+ if code >= http.StatusInternalServerError { + return code, http.StatusText(http.StatusInternalServerError), nil + } + return code, actual.Message, nil case *svalidator.Validation: ret := violations.MergeSqlViolation(actual.Violations) - return statusCode, err.Error(), ret + return http.StatusBadRequest, err.Error(), ret case *govalidator.Validation: ret := violations.MergeGoViolation(actual.Violations) - return statusCode, actual.Error(), ret + return http.StatusBadRequest, actual.Error(), ret case *response.Errors: - actual.SetStatusCode(statusCode) + maxStatus := actual.StatusCode() + if maxStatus == 0 { + maxStatus = statusCode + } + if maxStatus == 0 { + maxStatus = http.StatusBadRequest + } + hasServerError := maxStatus >= http.StatusInternalServerError || derrors.IsDatabaseError(errors.New(actual.Message)) + for _, anError := range actual.Errors { - isObj := types.IsObject(anError.Err) - if isObj { - statusCode, anError.Message, anError.Object = NormalizeErr(anError.Err, statusCode) - } else { - statusCode, anError.Message, anError.Object = NormalizeErr(anError.Err, statusCode) + if derrors.IsDatabaseError(anError.Err) || derrors.IsDatabaseError(errors.New(anError.Message)) { + anError.Code = http.StatusInternalServerError + anError.Message = http.StatusText(http.StatusInternalServerError) + hasServerError = true + } + + code := anError.StatusCode() + switch { + case code >= http.StatusInternalServerError: + anError.Message = http.StatusText(http.StatusInternalServerError) + hasServerError = true + case code == 0: + innerStatus, innerMsg, innerObj := NormalizeErr(anError.Err, maxStatus) + code = innerStatus + anError.Code = innerStatus + if innerMsg != "" { + anError.Message = innerMsg + } + if innerObj != nil { + anError.Object = innerObj + } + if code >= http.StatusInternalServerError { + hasServerError = true + } + default: + if code >= http.StatusInternalServerError { + hasServerError = true + } } + + if code > maxStatus { + maxStatus = code + } + } + + 
if hasServerError { + actual.Message = http.StatusText(http.StatusInternalServerError) + } else if actual.Message == "" && len(actual.Errors) > 0 { + actual.Message = actual.Errors[0].Message + } + + if maxStatus == 0 { + maxStatus = http.StatusBadRequest } - actual.SetStatusCode(statusCode) - return actual.StatusCode(), actual.Message, actual.Errors + + return maxStatus, actual.Message, actual.Errors case *expand.ErrorResponse: if actual.StatusCode != 0 { statusCode = actual.StatusCode } + // If no status code was set on the error response, treat it as a client error. + if statusCode == 0 { + statusCode = http.StatusBadRequest + } return statusCode, actual.Message, actual.Content default: + // Only DB-caused errors are mapped to 500 with a generic message. + if derrors.IsDatabaseError(err) { + return http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), nil + } + if statusCode == 0 { + statusCode = http.StatusBadRequest + } return statusCode, err.Error(), nil } } diff --git a/gateway/runtime/apigw/deploy.yaml b/gateway/runtime/apigw/deploy.yaml index 5a7d85dc6..0511616b0 100644 --- a/gateway/runtime/apigw/deploy.yaml +++ b/gateway/runtime/apigw/deploy.yaml @@ -25,7 +25,7 @@ pipeline: set_sdk: action: sdk.set target: $target - sdk: go:1.21 + sdk: go:1.25.1 build: package: diff --git a/gateway/runtime/apigw/handler.go b/gateway/runtime/apigw/handler.go index ccc7858d4..8c9748482 100644 --- a/gateway/runtime/apigw/handler.go +++ b/gateway/runtime/apigw/handler.go @@ -5,7 +5,6 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/viant/datly/gateway/runtime/serverless" "net/http" - "time" "github.com/viant/datly/gateway/router/proxy" "github.com/viant/datly/gateway/runtime/apigw/adapter" diff --git a/gateway/runtime/gcr/deploy.yaml b/gateway/runtime/gcr/deploy.yaml index 0257ecf68..7cc921e0d 100644 --- a/gateway/runtime/gcr/deploy.yaml +++ b/gateway/runtime/gcr/deploy.yaml @@ -18,7 +18,7 @@ pipeline: setSdk: action: sdk.set target: 
$target - sdk: go:1.21 + sdk: go:1.25.1 deploy: buildBinary: diff --git a/gateway/runtime/lambda/deploy.yaml b/gateway/runtime/lambda/deploy.yaml index 6a7fbae63..0ea83df8d 100644 --- a/gateway/runtime/lambda/deploy.yaml +++ b/gateway/runtime/lambda/deploy.yaml @@ -27,7 +27,7 @@ pipeline: set_sdk: action: sdk.set target: $target - sdk: go:1.21 + sdk: go:1.25.1 build: package: diff --git a/gateway/runtime/lambda/handler.go b/gateway/runtime/lambda/handler.go index eeeb163d6..3242af061 100644 --- a/gateway/runtime/lambda/handler.go +++ b/gateway/runtime/lambda/handler.go @@ -7,7 +7,6 @@ import ( "github.com/viant/datly/gateway/runtime/lambda/adapter" "github.com/viant/datly/gateway/runtime/serverless" "net/http" - "time" ) func HandleRequest(ctx context.Context, request *adapter.Request) (*events.LambdaFunctionURLResponse, error) { diff --git a/gateway/runtime/standalone/config.go b/gateway/runtime/standalone/config.go index 9318caa88..a6c769521 100644 --- a/gateway/runtime/standalone/config.go +++ b/gateway/runtime/standalone/config.go @@ -4,11 +4,13 @@ import ( "context" "encoding/json" "github.com/viant/afs" + "github.com/viant/afs/url" "github.com/viant/datly/gateway" "github.com/viant/datly/gateway/router/openapi/openapi3" "github.com/viant/datly/gateway/runtime/standalone/endpoint" "github.com/viant/toolbox" "gopkg.in/yaml.v3" + "path/filepath" "strings" ) @@ -71,5 +73,35 @@ func NewConfigFromURL(ctx context.Context, URL string) (*Config, error) { } cfg.URL = URL cfg.Init(ctx) + cfg.normalizeURLs(baseDir(URL)) return cfg, cfg.Validate() } + +func (c *Config) normalizeURLs(baseURL string) { + if url.IsRelative(c.RouteURL) { + c.RouteURL = url.Join(baseURL, c.RouteURL) + } + if url.IsRelative(c.ContentURL) { + c.ContentURL = url.Join(baseURL, c.ContentURL) + } + if url.IsRelative(c.PluginsURL) { + c.PluginsURL = url.Join(baseURL, c.PluginsURL) + } + if url.IsRelative(c.DependencyURL) { + c.DependencyURL = url.Join(baseURL, c.DependencyURL) + } + if c.JobURL != "" 
&& url.IsRelative(c.JobURL) { + c.JobURL = url.Join(baseURL, c.JobURL) + } + if c.FailedJobURL != "" && url.IsRelative(c.FailedJobURL) { + c.FailedJobURL = url.Join(baseURL, c.FailedJobURL) + } +} + +func baseDir(URL string) string { + if strings.Contains(URL, "://") { + parent, _ := url.Split(URL, "file") + return parent + } + return filepath.Dir(URL) +} diff --git a/gateway/service.go b/gateway/service.go index 3b91e519f..794c2b2d7 100644 --- a/gateway/service.go +++ b/gateway/service.go @@ -119,6 +119,9 @@ func New(ctx context.Context, opts ...Option) (*Service, error) { return nil, fmt.Errorf("failed to initialise component service: %w", err) } } + if err = (&Service{Config: aConfig}).applyDQLBootstrap(ctx, componentRepository, aConfig.DQLBootstrap); err != nil { + return nil, fmt.Errorf("failed to apply DQL bootstrap: %w", err) + } var mcpRegistry *serverproto.Registry if aConfig.MCP != nil { @@ -221,6 +224,17 @@ func (r *Service) syncChanges(ctx context.Context, metrics *gmetric.Service, sta return err } r.mux.Lock() + newCount := len(mainRouter.paths) + oldCount := 0 + if r.mainRouter != nil { + oldCount = len(r.mainRouter.paths) + } + if newCount < oldCount { + r.mux.Unlock() + fmt.Printf("[INFO]: routers rebuild skipped (new config has %d routes vs %d existing, keeping existing)\n", newCount, oldCount) + return nil + } + fmt.Printf("[INFO]: router replacing old(%d routes) with new(%d routes)\n", oldCount, newCount) r.mainRouter = mainRouter r.mux.Unlock() fmt.Printf("[INFO]: routers rebuild completed after: %s\n", time.Since(start)) diff --git a/go.mod b/go.mod index 89bb2f2dd..be6c46a28 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,12 @@ module github.com/viant/datly -go 1.23.8 +go 1.25.0 require ( github.com/aerospike/aerospike-client-go v4.5.2+incompatible github.com/aws/aws-lambda-go v1.31.0 github.com/francoispqt/gojay v1.2.13 github.com/go-sql-driver/mysql v1.7.0 - github.com/goccy/go-json v0.10.5 github.com/golang-jwt/jwt/v4 v4.5.1 // indirect 
github.com/google/gops v0.3.23 github.com/google/uuid v1.6.0 @@ -15,9 +14,9 @@ require ( github.com/lib/pq v1.10.6 github.com/mattn/go-sqlite3 v1.14.16 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.10.0 - github.com/viant/afs v1.26.2 - github.com/viant/afsc v1.9.1 + github.com/stretchr/testify v1.11.1 + github.com/viant/afs v1.29.0 + github.com/viant/afsc v1.16.0 github.com/viant/assertly v0.9.1-0.20220620174148-bab013f93a60 github.com/viant/bigquery v0.4.1 github.com/viant/cloudless v1.12.0 @@ -26,93 +25,113 @@ require ( github.com/viant/dyndb v0.1.4-0.20221214043424-27654ab6ed9c github.com/viant/gmetric v0.3.2 github.com/viant/godiff v0.4.1 - github.com/viant/parsly v0.3.3-0.20240717150634-e1afaedb691b + github.com/viant/parsly v0.3.3 github.com/viant/pgo v0.11.0 github.com/viant/scy v0.24.0 - github.com/viant/sqlx v0.17.6 - github.com/viant/structql v0.5.2 + github.com/viant/sqlx v0.21.0 + github.com/viant/structql v0.5.4 github.com/viant/toolbox v0.37.0 - github.com/viant/velty v0.2.1-0.20230927172116-ba56497b5c85 + github.com/viant/velty v0.4.0 github.com/viant/xreflect v0.7.3 github.com/viant/xunsafe v0.10.3 - golang.org/x/mod v0.25.0 - golang.org/x/oauth2 v0.30.0 - google.golang.org/api v0.174.0 + golang.org/x/mod v0.28.0 + golang.org/x/oauth2 v0.32.0 + google.golang.org/api v0.201.0 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/viant/govalidator v0.3.1 - github.com/viant/sqlparser v0.8.1 + github.com/viant/sqlparser v0.11.1-0.20260224194657-0470849e3588 ) require ( github.com/viant/aerospike v0.2.11-0.20241108195857-ed524b97800d github.com/viant/firebase v0.1.1 - github.com/viant/jsonrpc v0.7.2 - github.com/viant/mcp v0.4.3 - github.com/viant/mcp-protocol v0.4.4 - github.com/viant/structology v0.6.1 - github.com/viant/tagly v0.2.2 - github.com/viant/xdatly v0.5.4-0.20250806192028-819cadf93282 + github.com/viant/jsonrpc v0.17.0 + github.com/viant/mcp v0.11.0 + github.com/viant/mcp-protocol v0.11.0 + github.com/viant/structology v0.8.0 + 
github.com/viant/tagly v0.3.0 + github.com/viant/x v0.4.0 + github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 - github.com/viant/xdatly/handler v0.0.0-20250806192028-819cadf93282 + github.com/viant/xdatly/handler v0.0.0-20251208172928-dd34b7f09fd5 github.com/viant/xdatly/types/core v0.0.0-20250307183722-8c84fc717b52 github.com/viant/xdatly/types/custom v0.0.0-20240801144911-4c2bfca4c23a github.com/viant/xlsy v0.3.1 github.com/viant/xmlify v0.1.1 - golang.org/x/net v0.40.0 - golang.org/x/tools v0.33.0 + golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 + golang.org/x/tools v0.37.0 modernc.org/sqlite v1.18.1 ) require ( - cloud.google.com/go v0.112.1 // indirect - cloud.google.com/go/auth v0.2.0 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.0 // indirect - cloud.google.com/go/compute/metadata v0.3.0 // indirect - cloud.google.com/go/firestore v1.15.0 // indirect - cloud.google.com/go/iam v1.1.7 // indirect - cloud.google.com/go/longrunning v0.5.5 // indirect - cloud.google.com/go/secretmanager v1.11.5 // indirect - cloud.google.com/go/storage v1.40.0 // indirect + cel.dev/expr v0.24.0 // indirect + cloud.google.com/go v0.116.0 // indirect + cloud.google.com/go/auth v0.9.8 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + cloud.google.com/go/firestore v1.17.0 // indirect + cloud.google.com/go/iam v1.2.1 // indirect + cloud.google.com/go/longrunning v0.6.1 // indirect + cloud.google.com/go/monitoring v1.21.1 // indirect + cloud.google.com/go/secretmanager v1.14.1 // indirect + cloud.google.com/go/storage v1.45.0 // indirect firebase.google.com/go v3.13.0+incompatible // indirect firebase.google.com/go/v4 v4.14.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric 
v0.48.1 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect github.com/MicahParks/keyfunc v1.9.0 // indirect github.com/aerospike/aerospike-client-go/v6 v6.15.1 // indirect github.com/aws/aws-sdk-go v1.51.23 // indirect - github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect - github.com/aws/aws-sdk-go-v2/config v1.27.11 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.26 // indirect + github.com/aws/aws-sdk-go-v2 v1.32.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect + github.com/aws/aws-sdk-go-v2/config v1.28.0 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.41 // indirect github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.7 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 // indirect github.com/aws/aws-sdk-go-v2/service/dynamodb v1.17.8 // indirect github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.13.27 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.7.20 // indirect - 
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0 // indirect + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.2 // indirect github.com/aws/aws-sdk-go-v2/service/sns v1.31.3 // indirect github.com/aws/aws-sdk-go-v2/service/sqs v1.34.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.22.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect - github.com/aws/smithy-go v1.20.3 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssm v1.55.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 // indirect + github.com/aws/smithy-go v1.22.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect + github.com/envoyproxy/go-control-plane/envoy v1.35.0 // indirect + github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-errors/errors v1.5.1 // indirect - github.com/go-logr/logr v1.4.1 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect + github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect + github.com/goccy/go-json v0.10.2 // indirect github.com/golang-jwt/jwt/v5 v5.2.2 // indirect - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + 
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/golang/protobuf v1.5.4 // indirect - github.com/google/s2a-go v0.1.7 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.3 // indirect + github.com/google/s2a-go v0.1.8 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/googleapis/gax-go/v2 v2.13.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect @@ -125,35 +144,41 @@ require ( github.com/mazznoer/csscolorparser v0.1.3 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/nxadm/tail v1.4.8 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/richardlehane/mscfb v1.0.4 // indirect github.com/richardlehane/msoleps v1.0.3 // indirect + github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/viant/gosh v0.2.1 // indirect github.com/viant/igo v0.2.0 // indirect - github.com/viant/x v0.3.0 // indirect github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca // indirect github.com/xuri/excelize/v2 v2.8.0 // indirect github.com/xuri/nfp v0.0.0-20230819163627-dc951e3ffe1a // indirect github.com/yuin/gopher-lua v1.1.1 // indirect go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect - go.opentelemetry.io/otel v1.24.0 // indirect - go.opentelemetry.io/otel/metric v1.24.0 // indirect - 
go.opentelemetry.io/otel/trace v1.24.0 // indirect - golang.org/x/crypto v0.39.0 // indirect - golang.org/x/sync v0.15.0 // indirect - golang.org/x/sys v0.33.0 // indirect - golang.org/x/term v0.32.0 // indirect - golang.org/x/text v0.26.0 // indirect - golang.org/x/time v0.5.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/detectors/gcp v1.38.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect + go.opentelemetry.io/otel v1.38.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect + go.opentelemetry.io/otel/sdk v1.38.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect + go.opentelemetry.io/otel/trace v1.38.0 // indirect + golang.org/x/crypto v0.43.0 // indirect + golang.org/x/sync v0.17.0 // indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/term v0.36.0 // indirect + golang.org/x/text v0.30.0 // indirect + golang.org/x/time v0.7.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/appengine/v2 v2.0.2 // indirect - google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240314234333-6e1732d8331c // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be // indirect - google.golang.org/grpc v1.63.2 // indirect - google.golang.org/protobuf v1.33.0 // indirect + google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/grpc v1.77.0 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect lukechampine.com/uint128 v1.2.0 // indirect modernc.org/cc/v3 v3.36.3 
// indirect diff --git a/go.sum b/go.sum index 736b0b41e..04ca6ae0b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= +cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= @@ -39,8 +41,8 @@ cloud.google.com/go v0.105.0/go.mod h1:PrLgOJNe5nfE9UMxKxgXj4mD3voiP+YQ6gdt6KMFO cloud.google.com/go v0.107.0/go.mod h1:wpc2eNrD7hXUTy8EKS10jkxpZBjASrORK7goS+3YX2I= cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= cloud.google.com/go v0.110.2/go.mod h1:k04UEeEtb6ZBRTv3dZz4CeJC3jKGxyhl0sAiVVquxiw= -cloud.google.com/go v0.112.1 h1:uJSeirPke5UNZHIb4SxfZklVSiWWVqW4oXlETwZziwM= -cloud.google.com/go v0.112.1/go.mod h1:+Vbu+Y1UU+I1rjmzeMOb/8RfkKJK2Gyxi1X6jJCZLo4= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= cloud.google.com/go/accessapproval v1.4.0/go.mod h1:zybIuC3KpDOvotz59lFe5qxRZx6C75OtwbisN56xYB4= cloud.google.com/go/accessapproval v1.5.0/go.mod h1:HFy3tuiGvMdcd/u+Cu5b9NkO1pEICJ46IR82PoUdplw= cloud.google.com/go/accessapproval v1.6.0/go.mod h1:R0EiYnwV5fsRFiKZkPHr6mwyk2wxUJ30nL4j2pcFY2E= @@ -102,10 +104,10 @@ cloud.google.com/go/assuredworkloads v1.7.0/go.mod h1:z/736/oNmtGAyU47reJgGN+KVo cloud.google.com/go/assuredworkloads v1.8.0/go.mod h1:AsX2cqyNCOvEQC8RMPnoc0yEarXQk6WEKkxYfL6kGIo= cloud.google.com/go/assuredworkloads v1.9.0/go.mod h1:kFuI1P78bplYtT77Tb1hi0FMxM0vVpRC7VVoJC3ZoT0= cloud.google.com/go/assuredworkloads v1.10.0/go.mod h1:kwdUQuXcedVdsIaKgKTp9t0UJkE5+PAVNhdQm4ZVq2E= -cloud.google.com/go/auth v0.2.0 h1:y6oTcpMSbOcXbwYgUUrvI+mrQ2xbrcdpPgtVbCGTLTk= -cloud.google.com/go/auth 
v0.2.0/go.mod h1:+yb+oy3/P0geX6DLKlqiGHARGR6EX2GRtYCzWOCQSbU= -cloud.google.com/go/auth/oauth2adapt v0.2.0 h1:FR8zevgQwu+8CqiOT5r6xCmJa3pJC/wdXEEPF1OkNhA= -cloud.google.com/go/auth/oauth2adapt v0.2.0/go.mod h1:AfqujpDAlTfLfeCIl/HJZZlIxD8+nJoZ5e0x1IxGq5k= +cloud.google.com/go/auth v0.9.8 h1:+CSJ0Gw9iVeSENVCKJoLHhdUykDgXSc4Qn+gu2BRtR8= +cloud.google.com/go/auth v0.9.8/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= +cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY= +cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc= cloud.google.com/go/automl v1.5.0/go.mod h1:34EjfoFGMZ5sgJ9EoLsRtdPSNZLcfflJR39VbVNS2M0= cloud.google.com/go/automl v1.6.0/go.mod h1:ugf8a6Fx+zP0D59WLhqgTDsQI9w07o64uf/Is3Nh5p8= cloud.google.com/go/automl v1.7.0/go.mod h1:RL9MYCCsJEOmt0Wf3z9uzG0a7adTT1fe+aObgSpkCt8= @@ -186,8 +188,8 @@ cloud.google.com/go/compute/metadata v0.1.0/go.mod h1:Z1VN+bulIf6bt4P/C37K4DyZYZ cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.2.1/go.mod h1:jgHgmJd2RKBGzXqF5LR2EZMGxBkeanZ9wwa75XHJgOM= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/contactcenterinsights v1.3.0/go.mod h1:Eu2oemoePuEFc/xKFPjbTuPSj0fYJcPls9TFlPNnHHY= cloud.google.com/go/contactcenterinsights v1.4.0/go.mod h1:L2YzkGbPsv+vMQMCADxJoT9YiTTnSEd6fEvCeHTYVck= cloud.google.com/go/contactcenterinsights v1.6.0/go.mod h1:IIDlT6CLcDoyv79kDv8iWxMSTZhLxSCofVV5W6YFM/w= @@ -283,8 +285,8 @@ 
cloud.google.com/go/filestore v1.4.0/go.mod h1:PaG5oDfo9r224f8OYXURtAsY+Fbyq/bLY cloud.google.com/go/filestore v1.5.0/go.mod h1:FqBXDWBp4YLHqRnVGveOkHDf8svj9r5+mUDLupOWEDs= cloud.google.com/go/filestore v1.6.0/go.mod h1:di5unNuss/qfZTw2U9nhFqo8/ZDSc466dre85Kydllg= cloud.google.com/go/firestore v1.9.0/go.mod h1:HMkjKHNTtRyZNiMzu7YAsLr9K3X2udY2AMwDaMEQiiE= -cloud.google.com/go/firestore v1.15.0 h1:/k8ppuWOtNuDHt2tsRV42yI21uaGnKDEQnRFeBpbFF8= -cloud.google.com/go/firestore v1.15.0/go.mod h1:GWOxFXcv8GZUtYpWHw/w6IuYNux/BtmeVTMmjrm4yhk= +cloud.google.com/go/firestore v1.17.0 h1:iEd1LBbkDZTFsLw3sTH50eyg4qe8eoG6CjocmEXO9aQ= +cloud.google.com/go/firestore v1.17.0/go.mod h1:69uPx1papBsY8ZETooc71fOhoKkD70Q1DwMrtKuOT/Y= cloud.google.com/go/functions v1.6.0/go.mod h1:3H1UA3qiIPRWD7PeZKLvHZ9SaQhR26XIJcC0A5GbvAk= cloud.google.com/go/functions v1.7.0/go.mod h1:+d+QBcWM+RsrgZfV9xo6KfA1GlzJfxcfZcRPEhDDfzg= cloud.google.com/go/functions v1.8.0/go.mod h1:RTZ4/HsQjIqIYP9a9YPbU+QFoQsAlYgrwOXJWHn1POY= @@ -323,8 +325,8 @@ cloud.google.com/go/iam v0.8.0/go.mod h1:lga0/y3iH6CX7sYqypWJ33hf7kkfXJag67naqGE cloud.google.com/go/iam v0.11.0/go.mod h1:9PiLDanza5D+oWFZiH1uG+RnRCfEGKoyl6yo4cgWZGY= cloud.google.com/go/iam v0.12.0/go.mod h1:knyHGviacl11zrtZUoDuYpDgLjvr28sLQaG0YB2GYAY= cloud.google.com/go/iam v0.13.0/go.mod h1:ljOg+rcNfzZ5d6f1nAUJ8ZIxOaZUVoS14bKCtaLZ/D0= -cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM= -cloud.google.com/go/iam v1.1.7/go.mod h1:J4PMPg8TtyurAUvSmPj8FF3EDgY1SPRZxcUGrn7WXGA= +cloud.google.com/go/iam v1.2.1 h1:QFct02HRb7H12J/3utj0qf5tobFh9V4vR6h9eX5EBRU= +cloud.google.com/go/iam v1.2.1/go.mod h1:3VUIJDPpwT6p/amXRC5GY8fCCh70lxPygguVtI0Z4/g= cloud.google.com/go/iap v1.4.0/go.mod h1:RGFwRJdihTINIe4wZ2iCP0zF/qu18ZwyKxrhMhygBEc= cloud.google.com/go/iap v1.5.0/go.mod h1:UH/CGgKd4KyohZL5Pt0jSKE4m3FR51qg6FKQ/z/Ix9A= cloud.google.com/go/iap v1.6.0/go.mod h1:NSuvI9C/j7UdjGjIde7t7HBz+QTwBcapPE07+sSRcLk= @@ -354,11 +356,13 @@ 
cloud.google.com/go/lifesciences v0.6.0/go.mod h1:ddj6tSX/7BOnhxCSd3ZcETvtNr8NZ6 cloud.google.com/go/lifesciences v0.8.0/go.mod h1:lFxiEOMqII6XggGbOnKiyZ7IBwoIqA84ClvoezaA/bo= cloud.google.com/go/logging v1.6.1/go.mod h1:5ZO0mHHbvm8gEmeEUHrmDlTDSu5imF6MUP9OfilNXBw= cloud.google.com/go/logging v1.7.0/go.mod h1:3xjP2CjkM3ZkO73aj4ASA5wRPGGCRrPIAeNqVNkzY8M= +cloud.google.com/go/logging v1.11.0 h1:v3ktVzXMV7CwHq1MBF65wcqLMA7i+z3YxbUsoK7mOKs= +cloud.google.com/go/logging v1.11.0/go.mod h1:5LDiJC/RxTt+fHc1LAt20R9TKiUTReDg6RuuFOZ67+A= cloud.google.com/go/longrunning v0.1.1/go.mod h1:UUFxuDWkv22EuY93jjmDMFT5GPQKeFVJBIF6QlTqdsE= cloud.google.com/go/longrunning v0.3.0/go.mod h1:qth9Y41RRSUE69rDcOn6DdK3HfQfsUI0YSmW3iIlLJc= cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo= -cloud.google.com/go/longrunning v0.5.5 h1:GOE6pZFdSrTb4KAiKnXsJBtlE6mEyaW44oKyMILWnOg= -cloud.google.com/go/longrunning v0.5.5/go.mod h1:WV2LAxD8/rg5Z1cNW6FJ/ZpX4E4VnDnoTk0yawPBB7s= +cloud.google.com/go/longrunning v0.6.1 h1:lOLTFxYpr8hcRtcwWir5ITh1PAKUD/sG2lKrTSYjyMc= +cloud.google.com/go/longrunning v0.6.1/go.mod h1:nHISoOZpBcmlwbJmiVk5oDRz0qG/ZxPynEGs1iZ79s0= cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE= cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM= cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA= @@ -382,6 +386,8 @@ cloud.google.com/go/monitoring v1.7.0/go.mod h1:HpYse6kkGo//7p6sT0wsIC6IBDET0RhI cloud.google.com/go/monitoring v1.8.0/go.mod h1:E7PtoMJ1kQXWxPjB6mv2fhC5/15jInuulFdYYtlcvT4= cloud.google.com/go/monitoring v1.12.0/go.mod h1:yx8Jj2fZNEkL/GYZyTLS4ZtZEZN8WtDEiEqG4kLK50w= cloud.google.com/go/monitoring v1.13.0/go.mod h1:k2yMBAB1H9JT/QETjNkgdCGD9bPF712XiLTVr+cBrpw= +cloud.google.com/go/monitoring v1.21.1 h1:zWtbIoBMnU5LP9A/fz8LmWMGHpk4skdfeiaa66QdFGc= +cloud.google.com/go/monitoring v1.21.1/go.mod 
h1:Rj++LKrlht9uBi8+Eb530dIrzG/cU/lB8mt+lbeFK1c= cloud.google.com/go/networkconnectivity v1.4.0/go.mod h1:nOl7YL8odKyAOtzNX73/M5/mGZgqqMeryi6UPZTk/rA= cloud.google.com/go/networkconnectivity v1.5.0/go.mod h1:3GzqJx7uhtlM3kln0+x5wyFvuVH1pIBJjhCpjzSt75o= cloud.google.com/go/networkconnectivity v1.6.0/go.mod h1:OJOoEXW+0LAxHh89nXd64uGG+FbQoeH8DtxCHVOMlaM= @@ -490,8 +496,8 @@ cloud.google.com/go/secretmanager v1.6.0/go.mod h1:awVa/OXF6IiyaU1wQ34inzQNc4ISI cloud.google.com/go/secretmanager v1.8.0/go.mod h1:hnVgi/bN5MYHd3Gt0SPuTPPp5ENina1/LxM+2W9U9J4= cloud.google.com/go/secretmanager v1.9.0/go.mod h1:b71qH2l1yHmWQHt9LC80akm86mX8AL6X1MA01dW8ht4= cloud.google.com/go/secretmanager v1.10.0/go.mod h1:MfnrdvKMPNra9aZtQFvBcvRU54hbPD8/HayQdlUgJpU= -cloud.google.com/go/secretmanager v1.11.5 h1:82fpF5vBBvu9XW4qj0FU2C6qVMtj1RM/XHwKXUEAfYY= -cloud.google.com/go/secretmanager v1.11.5/go.mod h1:eAGv+DaCHkeVyQi0BeXgAHOU0RdrMeZIASKc+S7VqH4= +cloud.google.com/go/secretmanager v1.14.1 h1:xlWSIg8rtBn5qCr2f3XtQP19+5COyf/ll49SEvi/0vM= +cloud.google.com/go/secretmanager v1.14.1/go.mod h1:L+gO+u2JA9CCyXpSR8gDH0o8EV7i/f0jdBOrUXcIV0U= cloud.google.com/go/security v1.5.0/go.mod h1:lgxGdyOKKjHL4YG3/YwIL2zLqMFCKs0UbQwgyZmfJl4= cloud.google.com/go/security v1.7.0/go.mod h1:mZklORHl6Bg7CNnnjLH//0UlAlaXqiG7Lb9PsPXLfD0= cloud.google.com/go/security v1.8.0/go.mod h1:hAQOwgmaHhztFhiQ41CjDODdWP0+AE1B3sX4OFlq+GU= @@ -547,8 +553,8 @@ cloud.google.com/go/storage v1.23.0/go.mod h1:vOEEDNFnciUMhBeT6hsJIn3ieU5cFRmzeL cloud.google.com/go/storage v1.27.0/go.mod h1:x9DOL8TK/ygDUMieqwfhdpQryTeEkhGKMi80i/iqR2s= cloud.google.com/go/storage v1.28.1/go.mod h1:Qnisd4CqDdo6BGs2AD5LLnEsmSQ80wQ5ogcBBKhU86Y= cloud.google.com/go/storage v1.29.0/go.mod h1:4puEjyTKnku6gfKoTfNOU/W+a9JyuVNxjpS5GBrB8h4= -cloud.google.com/go/storage v1.40.0 h1:VEpDQV5CJxFmJ6ueWNsKxcr1QAYOXEgxDa+sBbJahPw= -cloud.google.com/go/storage v1.40.0/go.mod h1:Rrj7/hKlG87BLqDJYtwR0fbPld8uJPbQ2ucUMY7Ir0g= +cloud.google.com/go/storage v1.45.0 
h1:5av0QcIVj77t+44mV4gffFC/LscFRUhto6UBMB5SimM= +cloud.google.com/go/storage v1.45.0/go.mod h1:wpPblkIuMP5jCB/E48Pz9zIo2S/zD8g+ITmxKkPCITE= cloud.google.com/go/storagetransfer v1.5.0/go.mod h1:dxNzUopWy7RQevYFHewchb29POFv3/AaBgnhqzqiK0w= cloud.google.com/go/storagetransfer v1.6.0/go.mod h1:y77xm4CQV/ZhFZH75PLEXY0ROiS7Gh6pSKrM8dJyg6I= cloud.google.com/go/storagetransfer v1.7.0/go.mod h1:8Giuj1QNb1kfLAiWM1bN6dHzfdlDAVC9rv9abHot2W4= @@ -568,6 +574,8 @@ cloud.google.com/go/trace v1.3.0/go.mod h1:FFUE83d9Ca57C+K8rDl/Ih8LwOzWIV1krKgxg cloud.google.com/go/trace v1.4.0/go.mod h1:UG0v8UBqzusp+z63o7FK74SdFE+AXpCLdFb1rshXG+Y= cloud.google.com/go/trace v1.8.0/go.mod h1:zH7vcsbAhklH8hWFig58HvxcxyQbaIqMarMg9hn5ECA= cloud.google.com/go/trace v1.9.0/go.mod h1:lOQqpE5IaWY0Ixg7/r2SjixMuc6lfTFeO4QGM4dQWOk= +cloud.google.com/go/trace v1.11.1 h1:UNqdP+HYYtnm6lb91aNA5JQ0X14GnxkABGlfz2PzPew= +cloud.google.com/go/trace v1.11.1/go.mod h1:IQKNQuBzH72EGaXEodKlNJrWykGZxet2zgjtS60OtjA= cloud.google.com/go/translate v1.3.0/go.mod h1:gzMUwRjvOqj5i69y/LYLd8RrNQk+hOmIXTi9+nb3Djs= cloud.google.com/go/translate v1.4.0/go.mod h1:06Dn/ppvLD6WvA5Rhdp029IX2Mi3Mn7fpMRLPvXT5Wg= cloud.google.com/go/translate v1.5.0/go.mod h1:29YDSYveqqpA1CQFD7NQuP49xymq17RXNaUDdc0mNu0= @@ -628,6 +636,14 @@ git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGy git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1 
h1:UQ0AhxogsIRZDkElkblfnwjc3IaltCm2HUMvezQaL7s= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1/go.mod h1:jyqM3eLpJ3IbIFDTKVz2rF9T/xWGW0rIriGwnz8l9Tk= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.48.1 h1:oTX4vsorBZo/Zdum6OKPA4o7544hm6smoRv1QjpTwGo= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.48.1/go.mod h1:0wEl7vrAD8mehJyohS9HZy+WyEOaQO2mJx86Cvh93kM= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 h1:8nn+rsCvTq9axyEh382S0PFLBeaFwNsT43IrPWzctRU= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1/go.mod h1:viRWSEhtMZqz1rhwmOVKkWl6SwmVowfL9O2YR5gI2PE= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/MicahParks/keyfunc v1.9.0 h1:lhKd5xrFHLNOWrDc4Tyb/Q1AJ4LCzQ48GVJyVIID3+o= github.com/MicahParks/keyfunc v1.9.0/go.mod h1:IdnCilugA0O/99dW+/MkvlyrsX8+L8+x95xuVNtM5jw= @@ -652,48 +668,64 @@ github.com/aws/aws-lambda-go v1.31.0/go.mod h1:IF5Q7wj4VyZyUFnZ54IQqeWtctHQ9tz+K github.com/aws/aws-sdk-go v1.51.23 h1:/3TEdsEE/aHmdKGw2NrOp7Sdea76zfffGkTTSXTsDxY= github.com/aws/aws-sdk-go v1.51.23/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/aws/aws-sdk-go-v2 v1.17.2/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= -github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= -github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= -github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= -github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= -github.com/aws/aws-sdk-go-v2/credentials v1.17.26 h1:tsm8g/nJxi8+/7XyJJcP2dLrnK/5rkFp6+i2nhmz5fk= -github.com/aws/aws-sdk-go-v2/credentials v1.17.26/go.mod h1:3vAM49zkIa3q8WT6o9Ve5Z0vdByDMwmdScO0zvThTgI= 
+github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI= +github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA= +github.com/aws/aws-sdk-go-v2/config v1.28.0 h1:FosVYWcqEtWNxHn8gB/Vs6jOlNwSoyOCA/g/sxyySOQ= +github.com/aws/aws-sdk-go-v2/config v1.28.0/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc= +github.com/aws/aws-sdk-go-v2/credentials v1.17.41 h1:7gXo+Axmp+R4Z+AK8YFQO0ZV3L0gizGINCOWxSLY9W8= +github.com/aws/aws-sdk-go-v2/credentials v1.17.41/go.mod h1:u4Eb8d3394YLubphT4jLEwN1rLNq2wFOlT6OuxFwPzU= github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.7 h1:CyuByiiCA4lPfU8RaHJh2wIYYn0hkFlOkMfWkVY67Mc= github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.7/go.mod h1:pAMtgCPVxcKohC/HNI6nLwLeW007eYl3T+pq7yTMV3o= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 h1:TMH3f/SCAWdNtXXVPPu5D6wrr4G5hI1rAxbcocKfC7Q= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17/go.mod h1:1ZRXLdTpzdJb9fwTMXiLipENRxkGMTn1sfKexGllQCw= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 h1:X+4YY5kZRI/cOoSMVMGTqFXHAMg1bvvay7IBcqHpybQ= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33/go.mod h1:DPynzu+cn92k5UQ6tZhX+wfTB4ah6QDU/NgdHqatmvk= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.26/go.mod h1:2E0LdbJW6lbeU4uxjum99GZzI0ZjDpAb0CoSCM0oeEY= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod 
h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 h1:UAsR3xA31QGf79WzpG/ixT9FZvQlh5HY1NRqSHBNOCk= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21/go.mod h1:JNr43NFf5L9YaG3eKTm7HQzls9J+A9YYcGI5Quh1r2Y= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.20/go.mod h1:/+6lSiby8TBFpTVXZgKiN/rCfkYXEGvhlM4zCgPpt7w= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 h1:6jZVETqmYCadGFvrYEQfC5fAQmlo80CeL5psbno6r0s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21/go.mod h1:1SR0GbLlnN3QUmYaflZNiH1ql+1qrSiB2vwcJ+4UM60= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 h1:7edmS3VOBDhK00b/MwGtGglCm7hhwNYnjJs/PgFdMQE= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21/go.mod h1:Q9o5h4HoIWG8XfzxqiuK/CGUbepCJ8uTlaE3bAbxytQ= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.17.8 h1:VgdGaSIoH4JhUZIspT8UgK0aBF85TiLve7VHEx3NfqE= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.17.8/go.mod h1:jvXzk+hVrlkiQOvnq6jH+F6qBK0CEceXkEWugT+4Kdc= github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.13.27 h1:7MhqbR+k+b0gbOxp+W8yXgsl/Z5/dtMh85K0WI8X2EA= github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.13.27/go.mod h1:wX9QEZJ8Dw1fdAKCOAUmSvAe3wNJFxnE/4AeYc8blGA= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod 
h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 h1:4FMHqLfk0efmTqhXVRL5xYRqlEBNBiRI7N6w4jsEdd4= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2/go.mod h1:LWoqeWlK9OZeJxsROW2RqrSPvQHKTpp69r/iDjwsSaw= github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.7.20 h1:kSZR22oLBDMtP8ZPGXhz649NU77xsJDG7g3xfT6nHVk= github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.7.20/go.mod h1:lxM5qubwGNX29Qy+xTFG8G0r2Mj/TmyC+h3hS/7E4V8= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 h1:s7NA1SOw8q/5c0wr8477yOPp0z+uBaXBnLE0XYb0POA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2/go.mod h1:fnjjWyAW/Pj5HYOxl9LJqWtEwS7W2qgcRLWP+uWbss0= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 h1:t7iUP9+4wdc5lt3E41huP+GvQZJD38WLsgVp4iOtAjg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2/go.mod h1:/niFCtmuQNxqx9v8WAPq5qh7EH25U4BF6tjoyq9bObM= +github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0 h1:xA6XhTF7PE89BCNHJbQi8VvPzcgMtmGC5dr8S8N7lHk= +github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0/go.mod h1:cB6oAuus7YXRZhWCc1wIwPywwZ1XwweNp2TVAEGYeB8= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.2 
h1:Rrqru2wYkKQCS2IM5/JrgKUQIoNTqA6y/iuxkjzxC6M= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.2/go.mod h1:QuCURO98Sqee2AXmqDNxKXYFm2OEDAVAPApMqO0Vqnc= github.com/aws/aws-sdk-go-v2/service/sns v1.31.3 h1:eSTEdxkfle2G98FE+Xl3db/XAXXVTJPNQo9K/Ar8oAI= github.com/aws/aws-sdk-go-v2/service/sns v1.31.3/go.mod h1:1dn0delSO3J69THuty5iwP0US2Glt0mx2qBBlI13pvw= github.com/aws/aws-sdk-go-v2/service/sqs v1.34.3 h1:Vjqy5BZCOIsn4Pj8xzyqgGmsSqzz7y/WXbN3RgOoVrc= github.com/aws/aws-sdk-go-v2/service/sqs v1.34.3/go.mod h1:L0enV3GCRd5iG9B64W35C4/hwsCB00Ib+DKVGTadKHI= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.3 h1:Fv1vD2L65Jnp5QRsdiM64JvUM4Xe+E0JyVsRQKv6IeA= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.3/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= +github.com/aws/aws-sdk-go-v2/service/ssm v1.55.2 h1:z6Pq4+jtKlhK4wWJGHRGwMLGjC1HZwAO3KJr/Na0tSU= +github.com/aws/aws-sdk-go-v2/service/ssm v1.55.2/go.mod h1:DSmu/VZzpQlAubWBbAvNpt+S4k/XweglJi4XaDGyvQk= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 h1:bSYXVyUzoTHoKalBmwaZxs97HU9DWWI3ehHSAMa7xOk= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.2/go.mod h1:skMqY7JElusiOUjMJMOv1jJsP7YUg7DrhgqZZWuzu1U= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 h1:AhmO1fHINP9vFYUE0LHzCWg/LfUWUF+zFPEcY9QXb7o= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2/go.mod h1:o8aQygT2+MVP0NaV6kbdE1YnnIM8RRVQzoeUH45GOdI= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 h1:CiS7i0+FUe+/YY1GvIBLLrR/XNGZ4CtM1Ll0XavNuVo= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.2/go.mod 
h1:HtaiBI8CjYoNVde8arShXb94UbQQi9L4EMr6D+xGBwo= github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= -github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= -github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= +github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= @@ -705,6 +737,8 @@ github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91 github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -722,11 +756,14 @@ github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20220314180256-7f1daf1720fc/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go 
v0.0.0-20230310173818-32f1caf87195/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f h1:Y8xYupdHxryycyPlc9Y+bSQAYZnetRJ70VMVKm5CKI0= +github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f/go.mod h1:HlzOvOjVBOfTGSRXRyY0OiCS/3J1akRGQQpRO/7zyF4= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= @@ -744,10 +781,18 @@ github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go. 
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= github.com/envoyproxy/go-control-plane v0.10.3/go.mod h1:fJJn/j26vwOu972OllsvAgJJM//w9BV6Fxbg2LuVd34= github.com/envoyproxy/go-control-plane v0.11.0/go.mod h1:VnHyVMpzcLvCFt9yUz1UnCwHLhwx1WguiVDV7pTG/tI= +github.com/envoyproxy/go-control-plane v0.13.5-0.20251024222203-75eaa193e329 h1:K+fnvUM0VZ7ZFJf0n4L/BRlnsb9pL/GuDG6FqaH+PwM= +github.com/envoyproxy/go-control-plane v0.13.5-0.20251024222203-75eaa193e329/go.mod h1:Alz8LEClvR7xKsrq3qzoc4N0guvVNSS8KmSChGYr9hs= +github.com/envoyproxy/go-control-plane/envoy v1.35.0 h1:ixjkELDE+ru6idPxcHLj8LBVc2bFP7iBytj353BoHUo= +github.com/envoyproxy/go-control-plane/envoy v1.35.0/go.mod h1:09qwbGVuSWWAyN5t/b3iyVfz5+z8QWGrzkoqm/8SbEs= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.6.7/go.mod h1:dyJXwwfPK2VSqiB9Klm1J6romD608Ba7Hij42vrOBCo= github.com/envoyproxy/protoc-gen-validate v0.9.1/go.mod h1:OKNgG7TCp5pF4d6XftA0++PMirau2/yoOwVac3AbF2w= github.com/envoyproxy/protoc-gen-validate v0.10.0/go.mod h1:DRjgyB0I43LtJapqN6NiRwroiAU2PaFuvk/vjgh61ss= +github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= +github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= @@ -771,11 +816,13 @@ github.com/go-fonts/stix v0.1.0/go.mod 
h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmn github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= @@ -789,9 +836,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4 github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod 
h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= @@ -805,8 +851,9 @@ github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -857,8 +904,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp 
v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gops v0.3.23 h1:OjsHRINl5FiIyTc8jivIg4UN0GY6Nh32SL8KRbl8GQo= @@ -868,8 +915,9 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= -github.com/google/martian/v3 v3.3.2 h1:IqNFLAmvJOgVlpdEBiQbDc2EwKW77amAycfTuWKdfvw= github.com/google/martian/v3 v3.3.2/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -891,8 +939,8 @@ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm4 github.com/google/s2a-go v0.1.0/go.mod h1:OJpEgntRZo8ugHpF9hkoLJbS5dSI20XZeXJ9JVywLlM= github.com/google/s2a-go v0.1.3/go.mod 
h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= -github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= -github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= +github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -902,8 +950,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.1.0/go.mod h1:17drOmN3MwGY github.com/googleapis/enterprise-certificate-proxy v0.2.0/go.mod h1:8C0jb7/mgJe/9KK8Lm7X9ctZC2t60YyIpYEI16jx0Qg= github.com/googleapis/enterprise-certificate-proxy v0.2.1/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k= github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k= -github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= -github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= @@ -920,8 +968,8 @@ github.com/googleapis/gax-go/v2 v2.7.1/go.mod h1:4orTrqY6hXxxaUL4LHIPl6lGo8vAE38 
github.com/googleapis/gax-go/v2 v2.8.0/go.mod h1:4orTrqY6hXxxaUL4LHIPl6lGo8vAE38/qKbhSAKP6QI= github.com/googleapis/gax-go/v2 v2.10.0/go.mod h1:4UOEnMCrxsSqQ940WnTiD6qJ63le2ev3xfyagutxiPw= github.com/googleapis/gax-go/v2 v2.11.0/go.mod h1:DxmR61SGKkGLa2xigwuZIQpkCI2S5iydzRfb3peWZJI= -github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA= -github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= +github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s= +github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -957,8 +1005,9 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02 github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -1022,8 +1071,11 @@ github.com/pkg/errors v0.9.1 
h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -1041,8 +1093,9 @@ github.com/richardlehane/msoleps v1.0.3/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTK github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal 
v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w= github.com/ruudk/golang-pdf417 v0.0.0-20201230142125-a7e3863a1245/go.mod h1:pQAZKsJ8yyVxGRWYNEm9oFB8ieLgKFnamEyDmSA0BRk= @@ -1076,6 +1129,8 @@ github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasO github.com/spf13/afero v1.3.3/go.mod h1:5KUK8ByomD5Ti5Artl0RtHeI5pTF7MIDuXL3yY520V4= github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/spf13/afero v1.9.2/go.mod h1:iUV7ddyEEZPO5gA3zD4fJt6iStLlL+Lg4m2cihcDf8Y= +github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -1090,17 +1145,17 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod 
h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs= github.com/tklauser/numcpus v0.3.0/go.mod h1:yFGUr7TUHQRAhyqBcEg0Ge34zDBAsIvJJcyE6boqnA8= github.com/viant/aerospike v0.2.11-0.20241108195857-ed524b97800d h1:IRmoMmrWqkHDBy0tk9mbHRDK7+ynn0Gzwl+9WIiAtNs= github.com/viant/aerospike v0.2.11-0.20241108195857-ed524b97800d/go.mod h1:eRBywl0oTDM/oGhGLUeJjnC7XzmkTGuW9/og5YFy0K0= -github.com/viant/afs v1.26.2 h1:rOs/iFxFlEndhIRATJVXlNWhVU0cGdRQAGVTVJPdsc0= -github.com/viant/afs v1.26.2/go.mod h1:rScbFd9LJPGTM8HOI8Kjwee0AZ+MZMupAvFpPg+Qdj4= -github.com/viant/afsc v1.9.1 h1:BIus7fYyjM+MDgKuAzCBfoV4oVy2xTVhuFsQKUCPvkQ= -github.com/viant/afsc v1.9.1/go.mod h1:FA/xVjaMM10qGByabP8anTVMH6N4eUsAeWm5xcEZJJA= +github.com/viant/afs v1.29.0 h1:ndnn+PBQt5ep/bE1m5OvIvMjpoCCZbtl/UlJEubT9kE= +github.com/viant/afs v1.29.0/go.mod h1:rScbFd9LJPGTM8HOI8Kjwee0AZ+MZMupAvFpPg+Qdj4= +github.com/viant/afsc v1.16.0 h1:/kOH/flNwme6h3oFrU/KPnMHkhbCZxQncTf1GSQIlBQ= +github.com/viant/afsc v1.16.0/go.mod h1:Z6fP3VcmzS8Sg2lowctR6KkVEX7XxJ8aNaoHqhUiZkY= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/assertly v0.9.0/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/assertly v0.9.1-0.20220620174148-bab013f93a60 h1:VFJvCOHKXv4IqX8rJwn1otpHWQGgMDv2bXtAPgEzndM= @@ -1127,44 +1182,42 @@ github.com/viant/govalidator v0.3.1 h1:V7f/KgfzbP8fVDc+Kj+jyPvfXxMr2N1x7srOlDV6l github.com/viant/govalidator v0.3.1/go.mod h1:D35Dwx0R8rR1knRxhlseoYvOkiqo24kpMg1/o977i9Y= github.com/viant/igo v0.2.0 h1:ygWmTCinnGPaeV7omJLiyneOpzYZ5kiw7oYz7mUJZVQ= github.com/viant/igo v0.2.0/go.mod h1:7V6AWsLhKWeGzXNTNH3AZiIEKa0m33DrQbdWtapsI74= -github.com/viant/jsonrpc v0.7.2 h1:FUzhfFN76E09ZbQOxReFOyPhsxYhE0fjWzPhattR9Dk= -github.com/viant/jsonrpc v0.7.2/go.mod h1:LW2l5/H4KkGCsx2ktPX59iUlycw85ZlBcRuK/WYWBX8= -github.com/viant/mcp v0.4.3 
h1:ykQ2XyS2l5xrxHY5peJgIWoH+n8ZSpiSifnO/UH6/3I= -github.com/viant/mcp v0.4.3/go.mod h1:3SnILtYVIT8PIWICMyzP9KfhepawoFRv+//FBU/hc7c= -github.com/viant/mcp-protocol v0.4.4 h1:jKuCHvXeNof1Of1UfUyJkrSSNfOBiN4pXKWv3J2NwFM= -github.com/viant/mcp-protocol v0.4.4/go.mod h1:EL4NY7yW2gge+XLorgJA7PIazQX3x4ZkutYihwBwINs= -github.com/viant/parsly v0.3.3-0.20240717150634-e1afaedb691b h1:3q166tV28yFdbFV+tXXjH7ViKAmgAgGdoWzMtvhQv28= -github.com/viant/parsly v0.3.3-0.20240717150634-e1afaedb691b/go.mod h1:85fneXJbErKMGhSQto3A5ElTQCwl3t74U9cSV0waBHw= +github.com/viant/jsonrpc v0.17.0 h1:LZpe2H8tFUmWnvevDs2t6V7Cz7LzOGmpP8WcZipuXZE= +github.com/viant/jsonrpc v0.17.0/go.mod h1:b214Lo4zBwLqbu6Tf2bRlgQkFfPMBW5ap4qS+I3zcJ8= +github.com/viant/mcp v0.11.0 h1:dMcf5V5dPu3Ybpz7Q1nxj2fGmP/OKE1iM6MW3564eng= +github.com/viant/mcp v0.11.0/go.mod h1:mBSxAq6WvGpKRtWv3jknp+QU/oqjhov2Ab3nM9bp0F0= +github.com/viant/mcp-protocol v0.11.0 h1:22IuTTlq0L8l08z23TYRvHM/j19gp9UPExQdzwTuxsY= +github.com/viant/mcp-protocol v0.11.0/go.mod h1:EJPomVw6jnI+4Aa2ONYC3WTvApiF0YeQIiaaEpA54ec= +github.com/viant/parsly v0.3.3 h1:7ytgfLOG4Ils+wviGacWxRD0gAUvVEH/iGsSE+UI8YM= +github.com/viant/parsly v0.3.3/go.mod h1:85fneXJbErKMGhSQto3A5ElTQCwl3t74U9cSV0waBHw= github.com/viant/pgo v0.11.0 h1:PNuYVhwTfyrAHGBO6lxaMFuHP4NkjKV8ULecz3OWk8c= github.com/viant/pgo v0.11.0/go.mod h1:MFzHmkRFZlciugEgUvpl/3grK789PBSH4dUVSLOSo+Q= github.com/viant/scy v0.24.0 h1:KAC3IUARkQxTNSuwBK2YhVBJMOOLN30YaLKHbbuSkMU= github.com/viant/scy v0.24.0/go.mod h1:7uNRS67X45YN+JqTLCcMEhehffVjqrejULEDln9p0Ao= -github.com/viant/sqlparser v0.8.1 h1:nbcTecMtW7ROk5aNB5/BWUxnduepRPOkhVo9RWxI1Ns= -github.com/viant/sqlparser v0.8.1/go.mod h1:2QRGiGZYk2/pjhORGG1zLVQ9JO+bXFhqIVi31mkCRPg= -github.com/viant/sqlx v0.16.6 h1:3/D1/c3E8cMaUWTUBW56Gg/1vW4QMMWm42HkSAbzSZQ= -github.com/viant/sqlx v0.16.6/go.mod h1:dizufL+nTNqDCpivUnE2HqtddTp2TdA6WFghGfZo11c= -github.com/viant/sqlx v0.17.6 h1:6uMZVWk+WJl/y8coEh4F4mqbTHbtzWkLVEQdrk+m7sE= -github.com/viant/sqlx v0.17.6/go.mod 
h1:dizufL+nTNqDCpivUnE2HqtddTp2TdA6WFghGfZo11c= -github.com/viant/structology v0.6.1 h1:Forza+RF/1tmlQFk9ABNhu+IQ8vMAqbYM6FOsYtGh9E= -github.com/viant/structology v0.6.1/go.mod h1:63XfkzUyNw7wdi99HJIsH2Rg3d5AOumqbWLUYytOkxU= -github.com/viant/structql v0.5.2 h1:0dAratszxC6AD/TNaV8BnLQQprNO5GJHaKjmszrIoeY= -github.com/viant/structql v0.5.2/go.mod h1:nm9AYnAuSKH7b7pG+dKVxbQrr1Mgp1CQEMvUwwkE+I8= -github.com/viant/tagly v0.2.2 h1:qqb4Dov83i7nl7Gewph/lLaYAF8MKv0N7y34scgRNmE= -github.com/viant/tagly v0.2.2/go.mod h1:vV8QgJkhug+X+qyKds8av0fhjD+4u7IhNtowL1KGQ5A= +github.com/viant/sqlparser v0.11.1-0.20260224194657-0470849e3588 h1:bnVgWzZzuz2pTa54e7YozHjYNFSapfU3MSklyMkO+Ag= +github.com/viant/sqlparser v0.11.1-0.20260224194657-0470849e3588/go.mod h1:2QRGiGZYk2/pjhORGG1zLVQ9JO+bXFhqIVi31mkCRPg= +github.com/viant/sqlx v0.21.0 h1:Lx5KXmzfSjSvZZX5P0Ua9kFGvAmCxAjLOPe9pQA7VmY= +github.com/viant/sqlx v0.21.0/go.mod h1:woTOwNiqvt6SqkI+5nyzlixcRTTV0IvLZUTberqb8mo= +github.com/viant/structology v0.8.0 h1:WKdK67l+O1eqsubn8PWMhWcgspUGJ22SgJxUMfiRgqE= +github.com/viant/structology v0.8.0/go.mod h1:Fnm1DyR4gfyPbnhBMkQB5lR6/isYDnncBFO1nCxxmqs= +github.com/viant/structql v0.5.4 h1:bMdcOpzU8UMoe5OBcyJVRxLAndvU1oj3ysvPUgBckCI= +github.com/viant/structql v0.5.4/go.mod h1:nm9AYnAuSKH7b7pG+dKVxbQrr1Mgp1CQEMvUwwkE+I8= +github.com/viant/tagly v0.3.0 h1:Y8IckveeSrroR8yisq4MBdxhcNqf4v8II01uCpamh4E= +github.com/viant/tagly v0.3.0/go.mod h1:PauQQkHmAvL5lFGr4gIgi+PE0aUPggBIBYN34sX2Oes= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.34.5/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.37.0 h1:+zwSdbQh6I6ZEyxokQJr+1gQKbLEw6erc+Av5dwKtLU= github.com/viant/toolbox v0.37.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= -github.com/viant/velty v0.2.1-0.20230927172116-ba56497b5c85 h1:zKk+6hqUipkJXCPCHyFXzGtil1sfh80r6UZmloBNEDo= -github.com/viant/velty v0.2.1-0.20230927172116-ba56497b5c85/go.mod 
h1:Q/UXviI2Nli8WROEpYd/BELMCSvnulQeyNrbPmMiS/Y= -github.com/viant/x v0.3.0 h1:/3A0z/uySGxMo6ixH90VAcdjI00w5e3REC1zg5hzhJA= -github.com/viant/x v0.3.0/go.mod h1:54jP3qV+nnQdNDaWxEwGTAAzCu9sx9er9htiwTW/Mcw= -github.com/viant/xdatly v0.5.4-0.20250806192028-819cadf93282 h1:CqRQGsior7arN1lQA11oCoWdC/LZv1ObhCOGpdwvR3k= -github.com/viant/xdatly v0.5.4-0.20250806192028-819cadf93282/go.mod h1:lZKZHhVdCZ3U9TU6GUFxKoGN3dPtqt2HkDYzJPq5CEs= +github.com/viant/velty v0.4.0 h1:eesQES/vCpcoPbM+gQLUBuLEL2sEO+A6s6lPpl8eKc4= +github.com/viant/velty v0.4.0/go.mod h1:Q/UXviI2Nli8WROEpYd/BELMCSvnulQeyNrbPmMiS/Y= +github.com/viant/x v0.4.0 h1:n2xuxQdw4lYtMdi59IAQEZHPioNT9InENGGbapyz+P4= +github.com/viant/x v0.4.0/go.mod h1:1TvsnpZFqI9dYVzIkaSYJyJ/UkfxW7fnk0YFafWXrPg= +github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a h1:7CLO2LjVnFgOwN0FL3Q4y5NrD7DpclS21AiW6tDLIc8= +github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a/go.mod h1:lZKZHhVdCZ3U9TU6GUFxKoGN3dPtqt2HkDYzJPq5CEs= github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 h1:9Yry3PUBDzc4rWacOYvAq/TKrTV0agvMF0gwm2gaoHI= github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259/go.mod h1:fb8YgbVadk8X5ZLz49LWGzWmQlZd7Y/I5wE0ru44bIo= -github.com/viant/xdatly/handler v0.0.0-20250806192028-819cadf93282 h1:oNhkNyC6bRBifxWLyd7MTEFmCMwfg1LaAjKAmubrWCM= -github.com/viant/xdatly/handler v0.0.0-20250806192028-819cadf93282/go.mod h1:OeV4sVatklNs31nFnZtSp7lEvKJRoVJbH5opNRmRPg0= +github.com/viant/xdatly/handler v0.0.0-20251208172928-dd34b7f09fd5 h1:CrT0HTlQul8FoGN0peylVczAOUEXKVqRAiB35ypRNHY= +github.com/viant/xdatly/handler v0.0.0-20251208172928-dd34b7f09fd5/go.mod h1:OeV4sVatklNs31nFnZtSp7lEvKJRoVJbH5opNRmRPg0= github.com/viant/xdatly/types/core v0.0.0-20250307183722-8c84fc717b52 h1:G+e1MMDxQXUPPlAXVNlRqSLTLra7udGQZUu3hnr0Y8M= github.com/viant/xdatly/types/core v0.0.0-20250307183722-8c84fc717b52/go.mod h1:LJN2m8xJjtYNCvyvNrVanJwvzj8+hYCuPswL8H4qRG0= github.com/viant/xdatly/types/custom 
v0.0.0-20240801144911-4c2bfca4c23a h1:jecH7mH63gj1zJwD18SdvSHM9Ttr9FEOnhHkYfkCNkI= @@ -1205,18 +1258,24 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= -go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= -go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= -go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= -go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= -go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw= -go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= -go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= -go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/detectors/gcp v1.38.0 h1:ZoYbqX7OaA/TAikspPl3ozPI6iY6LiIY9I8cUfm+pJs= +go.opentelemetry.io/contrib/detectors/gcp v1.38.0/go.mod 
h1:SU+iU7nu5ud4oCb3LQOhIZ3nRLj6FNVrKgtflbaf2ts= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= @@ -1242,8 +1301,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.19.0/go.mod 
h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= -golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1306,8 +1365,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= +golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1377,8 +1436,8 @@ golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.16.0/go.mod 
h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -1412,8 +1471,8 @@ golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0= -golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= -golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= +golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1433,8 +1492,8 @@ golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.1.0/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1526,8 +1585,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1543,8 +1602,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term 
v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= +golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1564,8 +1623,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1573,8 +1632,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.1.0/go.mod 
h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1639,8 +1698,8 @@ golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= -golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= +golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1649,12 +1708,12 @@ golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= golang.org/x/xerrors 
v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0= gonum.org/v1/gonum v0.11.0/go.mod h1:fSG4YDCxxUZQJ7rKsQrj0gMOg00Il0Z96/qMA4bVQhA= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY= @@ -1723,8 +1782,8 @@ google.golang.org/api v0.118.0/go.mod h1:76TtD3vkgmZ66zZzp72bUUklpmQmKlhh6sYtIjY google.golang.org/api v0.122.0/go.mod h1:gcitW0lvnyWjSp9nKxAbdHKIZ6vF4aajGueeslZOyms= google.golang.org/api v0.124.0/go.mod h1:xu2HQurE5gi/3t1aFCvhPD781p0a3p11sdunTJ2BlP4= google.golang.org/api v0.126.0/go.mod h1:mBwVAtz+87bEN6CbA1GtZPDOqY2R5ONPqJeIlvyo4Aw= -google.golang.org/api v0.174.0 h1:zB1BWl7ocxfTea2aQ9mgdzXjnfPySllpPOskdnO+q34= -google.golang.org/api v0.174.0/go.mod h1:aC7tB6j0HR1Nl0ni5ghpx6iLasmAX78Zkh/wgxAAjLg= +google.golang.org/api v0.201.0 h1:+7AD9JNM3tREtawRMu8sOjSbb8VYcYXJG/2eEOmfDu0= +google.golang.org/api v0.201.0/go.mod h1:HVY0FCHVs89xIW9fzf/pBvOEm+OolHa86G/txFezyq4= google.golang.org/appengine v1.1.0/go.mod 
h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1877,21 +1936,21 @@ google.golang.org/genproto v0.0.0-20230403163135-c38d8f061ccd/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= google.golang.org/genproto v0.0.0-20230525234025-438c736192d0/go.mod h1:9ExIQyXL5hZrHzQceCwuSYwZZ5QZBazOcprJ5rgs3lY= google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:xZnkP7mREFX5MORlOPEzLMr+90PPZQ2QWzrVTWfAq64= -google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de h1:F6qOa9AZTYJXOUEr4jDysRDLrm4PHePlge4v4TGAlxY= -google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:VUhTRKeHn9wwcdrk73nvdC9gF178Tzhmt/qyaFcPLSo= +google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 h1:Df6WuGvthPzc+JiQ/G+m+sNX24kc0aTBqoDN/0yyykE= +google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53/go.mod h1:fheguH3Am2dGp1LfXkrvwqC/KlFq8F0nLq3LryOMrrE= google.golang.org/genproto/googleapis/api v0.0.0-20230525234020-1aefcd67740a/go.mod h1:ts19tUU+Z0ZShN1y3aPyq2+O3d5FUNNgT6FtOzmrNn8= google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= google.golang.org/genproto/googleapis/api v0.0.0-20230526203410-71b5a4ffd15e/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= -google.golang.org/genproto/googleapis/api v0.0.0-20240314234333-6e1732d8331c h1:kaI7oewGK5YnVwj+Y+EJBO/YN1ht8iTL9XkFHtVZLsc= -google.golang.org/genproto/googleapis/api v0.0.0-20240314234333-6e1732d8331c/go.mod h1:VQW3tUculP/D4B+xVCo+VgSq8As6wA9ZjHl//pmk+6s= +google.golang.org/genproto/googleapis/api 
v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= google.golang.org/genproto/googleapis/bytestream v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:ylj+BE99M198VPbBh6A8d9n3w8fChvyLK3wwBOjXBFA= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234015-3fc162c6f38a/go.mod h1:xURIpW9ES5+/GZhnV6beoEtxQrnkRGIfP5VQG2tCBLc= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= google.golang.org/genproto/googleapis/rpc v0.0.0-20230526203410-71b5a4ffd15e/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be h1:LG9vZxsWGOmUKieR8wPAUR3u3MpnYFQZROPIMaXh7/A= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= @@ -1936,8 +1995,8 @@ google.golang.org/grpc v1.52.0/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5v google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw= google.golang.org/grpc v1.54.0/go.mod h1:PUSEXI6iWghWaB6lXM4knEgpJNu2qUcKfDtNci3EC2g= google.golang.org/grpc v1.55.0/go.mod 
h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8= -google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= -google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -1957,8 +2016,8 @@ google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/internal/codegen/ast/assign.go b/internal/codegen/ast/assign.go index f2f9d4756..f058ff248 100644 --- a/internal/codegen/ast/assign.go +++ b/internal/codegen/ast/assign.go @@ -54,7 +54,7 @@ func (s *Assign) 
Generate(builder *Builder) (err error) { return nil } - if err = builder.WriteString("\n"); err != nil { + if err = builder.WriteIndentedString("\n"); err != nil { return err } asIdent, ok := s.Holder.(*Ident) @@ -84,7 +84,6 @@ func (s *Assign) Generate(builder *Builder) (err error) { if err = s.Expression.Generate(builder); err != nil { return err } - builder.WriteString("\n") if !wasDeclared { builder.State.DeclareVariable(asIdent.Name) } diff --git a/internal/codegen/ast/condition.go b/internal/codegen/ast/condition.go index 7c3a37509..60b88b3a5 100644 --- a/internal/codegen/ast/condition.go +++ b/internal/codegen/ast/condition.go @@ -83,10 +83,6 @@ func (s *Condition) Generate(builder *Builder) (err error) { } bodyBlockBuilder := builder.IncIndent(" ") - if err = bodyBlockBuilder.WriteIndentedString("\n"); err != nil { - return err - } - if err = s.IFBlock.Generate(bodyBlockBuilder); err != nil { return err } @@ -108,10 +104,6 @@ func (s *Condition) Generate(builder *Builder) (err error) { return err } - if err = bodyBlockBuilder.WriteIndentedString("\n"); err != nil { - return err - } - if err = block.Block.Generate(bodyBlockBuilder); err != nil { return err } @@ -130,10 +122,6 @@ func (s *Condition) Generate(builder *Builder) (err error) { return err } - if err = bodyBlockBuilder.WriteIndentedString("\n"); err != nil { - return err - } - if err = s.ElseBlock.Generate(bodyBlockBuilder); err != nil { return err } diff --git a/internal/codegen/handler.go b/internal/codegen/handler.go index 5c732a16e..dafbd99a8 100644 --- a/internal/codegen/handler.go +++ b/internal/codegen/handler.go @@ -2,11 +2,12 @@ package codegen import ( _ "embed" + "strings" + "github.com/viant/datly/cmd/options" "github.com/viant/datly/internal/codegen/ast" "github.com/viant/datly/internal/inference" "github.com/viant/datly/internal/plugin" - "strings" ) //go:embed tmpl/handler/handler.gox diff --git a/internal/inference/join_test.go b/internal/inference/join_test.go new file mode 100644 
index 000000000..ea30ddb95 --- /dev/null +++ b/internal/inference/join_test.go @@ -0,0 +1,56 @@ +package inference + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/sqlparser" +) + +func TestJoinRelationExtraction(t *testing.T) { + testCases := []struct { + description string + sql string + wantParent string + wantRelCol string + wantRefCol string + }{ + { + description: "simple join", + sql: "SELECT * FROM a a JOIN b b ON a.brand = b.b_brand", + wantParent: "a", + wantRelCol: "brand", + wantRefCol: "b_brand", + }, + { + description: "join with function on parent", + sql: "SELECT * FROM a a JOIN b b ON lower(a.brand) = b.b_brand", + wantParent: "a", + wantRelCol: "brand", + wantRefCol: "b_brand", + }, + { + description: "join with collate and multiple conditions", + sql: "SELECT * FROM a a JOIN b b ON " + + "a.brand COLLATE utf8mb4_bin = b.b_brand COLLATE utf8mb4_bin AND " + + "a.model COLLATE utf8mb4_bin = b.b_model COLLATE utf8mb4_bin", + wantParent: "a", + wantRelCol: "brand", + wantRefCol: "b_brand", + }, + } + + for _, testCase := range testCases { + q, err := sqlparser.ParseQuery(testCase.sql) + require.NoError(t, err, testCase.description) + require.NotEmpty(t, q.Joins, testCase.description) + + join := q.Joins[0] + parent := ParentAlias(join) + require.Equal(t, testCase.wantParent, parent, testCase.description) + + relCol, refCol := ExtractRelationColumns(join) + require.Equal(t, testCase.wantRelCol, relCol, testCase.description) + require.Equal(t, testCase.wantRefCol, refCol, testCase.description) + } +} diff --git a/internal/inference/parameter.go b/internal/inference/parameter.go index d3403859f..cc2916e69 100644 --- a/internal/inference/parameter.go +++ b/internal/inference/parameter.go @@ -4,6 +4,12 @@ import ( "embed" _ "embed" "fmt" + "go/ast" + "path" + "reflect" + "strconv" + "strings" + "github.com/viant/datly/view" "github.com/viant/datly/view/state" "github.com/viant/datly/view/tags" @@ -15,11 +21,6 @@ import 
( "github.com/viant/tagly/format/text" "github.com/viant/xreflect" "github.com/viant/xunsafe" - "go/ast" - "path" - "reflect" - "strconv" - "strings" ) type ( @@ -34,6 +35,7 @@ type ( AssumedType bool Connector string Cache string + Limit *int InOutput bool Of string } @@ -78,18 +80,19 @@ func (p *Parameter) veltyDeclaration(builder *strings.Builder) { case state.KindParam: builder.WriteString("?") default: + isPtr := strings.HasPrefix(p.Schema.DataType, "*") if p.Schema.Cardinality == state.Many { builder.WriteString("[]") - switch p.In.Kind { case "query", "form", "header": default: - if !p.IsRequired() { + if !p.IsRequired() && !isPtr { + isPtr = true builder.WriteString("*") } } - } else if !p.IsRequired() { + } else if !p.IsRequired() && !isPtr { builder.WriteString("*") } builder.WriteString(p.Schema.DataType) @@ -115,6 +118,10 @@ func (p *Parameter) veltyDeclaration(builder *strings.Builder) { builder.WriteString(".WithCache('" + p.Cache + "')") } + if p.Limit != nil { + builder.WriteString(".WithLimit('" + strconv.Itoa(*p.Limit) + "')") + } + if p.Required != nil { if !*p.Required { builder.WriteString(".Optional()") @@ -122,6 +129,10 @@ func (p *Parameter) veltyDeclaration(builder *strings.Builder) { builder.WriteString(".Required()") } } + + if p.Cacheable != nil { + builder.WriteString(".Cacheable('" + strconv.FormatBool(*p.Cacheable) + "')") + } if p.Connector != "" { builder.WriteString(".WithConnector('" + p.Connector + "')") } @@ -304,6 +315,9 @@ func buildParameter(field *xunsafe.Field, aTag *tags.Tag, types *xreflect.Types, if aTag.View.Cache != "" { param.Cache = aTag.View.Cache } + if aTag.View.Limit != nil { + param.Limit = aTag.View.Limit + } } fType := field.Type @@ -340,16 +354,9 @@ func ParentAlias(join *query.Join) string { result := "" sqlparser.Traverse(join.On, func(n node.Node) bool { switch actual := n.(type) { - case *qexpr.Binary: - if xSel, ok := actual.X.(*qexpr.Selector); ok { - if xSel.Name != join.Alias { - result = xSel.Name - 
} - } - if ySel, ok := actual.Y.(*qexpr.Selector); ok { - if ySel.Name != join.Alias { - result = ySel.Name - } + case *qexpr.Selector: + if actual.Name != "" && actual.Name != join.Alias { + result = actual.Name } return true } @@ -363,20 +370,14 @@ func ExtractRelationColumns(join *query.Join) (string, string) { refColumn := "" sqlparser.Traverse(join.On, func(n node.Node) bool { switch actual := n.(type) { - case *qexpr.Binary: - if xSel, ok := actual.X.(*qexpr.Selector); ok { - if xSel.Name == join.Alias { - refColumn = sqlparser.Stringify(xSel.X) - } else if relColumn == "" { - relColumn = sqlparser.Stringify(xSel.X) - } - } - if ySel, ok := actual.Y.(*qexpr.Selector); ok { - if ySel.Name == join.Alias { - refColumn = sqlparser.Stringify(ySel.X) - } else if relColumn == "" { - relColumn = sqlparser.Stringify(ySel.X) + case *qexpr.Selector: + column := sqlparser.Stringify(actual.X) + if actual.Name == join.Alias { + if refColumn == "" { + refColumn = column } + } else if relColumn == "" { + relColumn = column } return true } diff --git a/internal/inference/spec.go b/internal/inference/spec.go index 52fc6068d..e6a562199 100644 --- a/internal/inference/spec.go +++ b/internal/inference/spec.go @@ -4,6 +4,9 @@ import ( "context" "database/sql" "fmt" + "reflect" + "strings" + "github.com/viant/datly/internal/msg" "github.com/viant/datly/view" "github.com/viant/datly/view/column" @@ -14,8 +17,6 @@ import ( "github.com/viant/sqlx/metadata/info" "github.com/viant/sqlx/metadata/sink" "github.com/viant/sqlx/option" - "reflect" - "strings" ) type ( @@ -246,7 +247,13 @@ func NewSpec(ctx context.Context, db *sql.DB, messages *msg.Messages, table stri var result = &Spec{Table: table, SQL: SQL, SQLArgs: SQLArgs, IsAuxiliary: isAuxiliary} columns, err := column.Discover(ctx, db, table, SQL, SQLArgs...) 
if err != nil { - return nil, err + columns = bestEffortColumnsFromSQL(SQL, columnsConfig) + if len(columns) == 0 { + return nil, err + } + if messages != nil { + messages.AddWarning(result.Table, "detection", fmt.Sprintf("using best-effort SQL column inference due to discovery error: %v", err)) + } } result.Columns = columns byName := result.Columns.ByName() @@ -285,6 +292,56 @@ func NewSpec(ctx context.Context, db *sql.DB, messages *msg.Messages, table stri return result, nil } +func bestEffortColumnsFromSQL(SQL string, columnsConfig view.ColumnConfigs) sqlparser.Columns { + if strings.TrimSpace(SQL) == "" { + return nil + } + query, err := sqlparser.ParseQuery(SQL) + if err != nil || query == nil { + return nil + } + queryColumns := sqlparser.NewColumns(query.List) + if len(queryColumns) == 0 { + return nil + } + cfgByLower := map[string]*view.ColumnConfig{} + for _, cfg := range columnsConfig { + if cfg == nil || cfg.Name == "" { + continue + } + cfgByLower[strings.ToLower(cfg.Name)] = cfg + } + var result sqlparser.Columns + for _, candidate := range queryColumns { + if candidate == nil { + continue + } + expression := strings.TrimSpace(candidate.Expression) + if expression == "*" || strings.HasSuffix(expression, ".*") { + continue + } + name := strings.TrimSpace(candidate.Alias) + if name == "" { + name = strings.TrimSpace(candidate.Name) + } + if name == "" { + continue + } + if candidate.Type == "" { + if cfg, ok := cfgByLower[strings.ToLower(name)]; ok && cfg.DataType != nil && *cfg.DataType != "" { + candidate.Type = *cfg.DataType + } else if cfg, ok = cfgByLower[strings.ToLower(candidate.Name)]; ok && cfg.DataType != nil && *cfg.DataType != "" { + candidate.Type = *cfg.DataType + } + } + if candidate.Type == "" { + candidate.Type = "string" + } + result = append(result, candidate) + } + return result +} + func isAuxiliary(SQL *string) bool { if *SQL == "" { return false diff --git a/internal/inference/state.go b/internal/inference/state.go index 
e00c2b3b0..fa28c984f 100644 --- a/internal/inference/state.go +++ b/internal/inference/state.go @@ -3,6 +3,12 @@ package inference import ( "context" "fmt" + "go/ast" + "go/parser" + "path" + "reflect" + "strings" + "github.com/viant/afs" "github.com/viant/afs/embed" "github.com/viant/afs/file" @@ -19,11 +25,6 @@ import ( "github.com/viant/toolbox/data" "github.com/viant/xreflect" "github.com/viant/xunsafe" - "go/ast" - "go/parser" - "path" - "reflect" - "strings" ) // State defines datly view/resource parameters @@ -490,7 +491,13 @@ func (s State) EnsureReflectTypes(modulePath string, pkg string, registry *xrefl if err != nil { rType, err = types.LookupType(typeRegistry.Lookup, dataType, xreflect.WithPackage(pkg)) if err != nil { - return err + rType = reflect.TypeOf((*interface{})(nil)).Elem() + if param.Schema.DataType == "" { + param.Schema.DataType = "interface{}" + } + if param.Schema.Package == "" { + param.Schema.Package = pkg + } } } param.Schema.SetType(rType) @@ -691,12 +698,20 @@ func NewState(packageLocation, dataType string, types *xreflect.Types) (State, e } state.BuildPredicate(aTag, ¶m.Parameter) state.BuildCodec(aTag, ¶m.Parameter) + if param.Schema.DataType == "" { compType := param.Schema.CompType() + paramTypeName := compType.String() + if compType.Kind() == reflect.Pointer { compType = compType.Elem() + paramTypeName := compType.String() + + if compType.Kind() == reflect.Struct { + paramTypeName = "*" + paramTypeName + } } - param.Schema.DataType = compType.String() + param.Schema.DataType = paramTypeName param.Schema.PackagePath = compType.PkgPath() } //} @@ -776,6 +791,13 @@ func discoverStateType(baseDir string, types *xreflect.Types, dataType string, p return nil, err } var rType = xunsafe.LookupType(dirTypes.ModulePath + "/" + dataType) + + if rType == nil && types != nil && strings.Count(pkg, "/") > 1 { //the last resort fallback collission protection + pkg = strings.Replace(pkg, "pkg/", "", 1) + rType, _ = types.Lookup(dataType, 
xreflect.WithPackage(pkg)) + + } + if rType == nil && len(stateTypeFields) > 0 { rType = reflect.StructOf(stateTypeFields) } diff --git a/internal/inference/struct.go b/internal/inference/struct.go index 09affc9be..03cdf0cc5 100644 --- a/internal/inference/struct.go +++ b/internal/inference/struct.go @@ -35,16 +35,30 @@ func (p *parameterStruct) Add(name string, parameter *Parameter) { } func (p *parameterStruct) reflectType() reflect.Type { - return p.structField().Type + field := p.structField() + return field.Type } func (p *parameterStruct) structField() reflect.StructField { - if p.Parameter != nil && (p.Parameter.In.Kind != state.KindObject) { + if p == nil { + return reflect.StructField{} + } + if p.Parameter != nil && (p.Parameter.In == nil || p.Parameter.In.Kind != state.KindObject) { return reflect.StructField{Name: p.name, Type: p.Parameter.Schema.Type(), Tag: reflect.StructTag(p.Parameter.Tag), PkgPath: xreflect.PkgPath(p.Parameter.Name, p.Parameter.Schema.Package)} } var fields []reflect.StructField for _, f := range p.fields { - fields = append(fields, f.structField()) + if f == nil { + continue + } + field := f.structField() + if field.Name == "" || field.Type == nil { + continue + } + fields = append(fields, field) + } + if len(fields) == 0 { + return reflect.StructField{Name: p.name, Type: reflect.TypeOf(struct{}{})} } pkgPath := "" if p.name != "" { diff --git a/internal/inference/type.go b/internal/inference/type.go index 3694aea03..3982045b4 100644 --- a/internal/inference/type.go +++ b/internal/inference/type.go @@ -229,11 +229,13 @@ func NewType(packageName string, name string, rType reflect.Type) (*Type, error) rType = types.EnsureStruct(rType) if rType.NumField() == 1 { wrapperField := rType.Field(0) - if canidateType, _ := wrapperField.Tag.Lookup("typeName"); canidateType != "" { - name = canidateType + if types.EnsureStruct(wrapperField.Type) != nil { + if canidateType, _ := wrapperField.Tag.Lookup("typeName"); canidateType != "" { + name 
= canidateType + } + structType := types.EnsureStruct(wrapperField.Type) + return NewType(packageName, name, structType) } - structType := types.EnsureStruct(wrapperField.Type) - return NewType(packageName, name, structType) } for i := 0; i < rType.NumField(); i++ { diff --git a/internal/testutil/sqlnormalizer/cases.go b/internal/testutil/sqlnormalizer/cases.go new file mode 100644 index 000000000..73569e24a --- /dev/null +++ b/internal/testutil/sqlnormalizer/cases.go @@ -0,0 +1,43 @@ +package sqlnormalizer + +type Case struct { + Name string + Generated bool + SQL string + Expect string +} + +func Cases() []Case { + return []Case{ + { + Name: "skip normalization when not generated", + Generated: false, + SQL: "SELECT a.id FROM users a JOIN orders b ON a.id = b.user_id", + Expect: "SELECT a.id FROM users a JOIN orders b ON a.id = b.user_id", + }, + { + Name: "invalid sql returns input", + Generated: true, + SQL: "SELECT * FROM (", + Expect: "SELECT * FROM (", + }, + { + Name: "normalize from and join aliases in selectors and alias nodes", + Generated: true, + SQL: "SELECT a.id, b.user_id FROM users a JOIN orders b ON a.id = b.user_id", + Expect: "SELECT A.id, B.user_id FROM users A JOIN orders B ON A.id = B.user_id", + }, + { + Name: "keep alias that is already normalized", + Generated: true, + SQL: "SELECT UserAlias.id FROM users UserAlias", + Expect: "SELECT UserAlias.id FROM users UserAlias", + }, + { + Name: "normalize snake_case alias", + Generated: true, + SQL: "SELECT order_item.id FROM users order_item", + Expect: "SELECT OrderItem.id FROM users OrderItem", + }, + } +} diff --git a/internal/translator/config.go b/internal/translator/config.go index d8adfd670..656b6f430 100644 --- a/internal/translator/config.go +++ b/internal/translator/config.go @@ -156,10 +156,10 @@ func (c *Config) NormalizeURL(repositoryURL string) { cfg.ContentURL = url.Join(baseURL, cfg.ContentURL) } if url.IsRelative(cfg.PluginsURL) { - cfg.RouteURL = url.Join(baseURL, 
cfg.PluginsURL) + cfg.PluginsURL = url.Join(baseURL, cfg.PluginsURL) } if url.IsRelative(cfg.DependencyURL) { - cfg.RouteURL = url.Join(baseURL, cfg.DependencyURL) + cfg.DependencyURL = url.Join(baseURL, cfg.DependencyURL) } cfg.URL = url.Join(baseURL, "config.json") } diff --git a/internal/translator/function.go b/internal/translator/function.go index 923250d59..e7b995bce 100644 --- a/internal/translator/function.go +++ b/internal/translator/function.go @@ -12,6 +12,8 @@ import ( "strings" ) +const privateColumnTag = `internal:"true" json:"-"` + // TODO introduce function abstraction for datly -h list funciton, with validation signtaure description func (n *Viewlets) applySettingFunctions(column *sqlparser.Column, namespace string) (bool, error) { funcName, funcArgs := extractFunction(column) @@ -54,16 +56,15 @@ func (n *Viewlets) applySettingFunctions(column *sqlparser.Column, namespace str if dest != nil { switch strings.ToLower(funcName) { case "tag": - if column.Name == column.Namespace && !strings.Contains(column.Expression, column.Name+"."+column.Name) { - if dest.View == nil { - dest.View = &View{} - } - dest.View.Tag = strings.Trim(column.Tag, "'") - return true, nil + if err := applyColumnTagSetting(dest, column); err != nil { + return false, err + } + return true, nil + case "private": + column.Tag = privateColumnTag + if err := applyColumnTagSetting(dest, column); err != nil { + return false, err } - columnConfig := dest.columnConfig(column.Name) - column.Tag = strings.Trim(strings.TrimSpace(column.Tag), "'") - columnConfig.Tag = &column.Tag return true, nil case "cast": return dest.applyExplicitCast(column, funcArgs) @@ -101,6 +102,20 @@ func (n *Viewlets) applySettingFunctions(column *sqlparser.Column, namespace str return true, nil } +func applyColumnTagSetting(dest *Viewlet, column *sqlparser.Column) error { + if column.Name == column.Namespace && !strings.Contains(column.Expression, column.Name+"."+column.Name) { + if dest.View == nil { + dest.View 
= &View{} + } + dest.View.Tag = strings.Trim(column.Tag, "'") + return nil + } + columnConfig := dest.columnConfig(column.Name) + column.Tag = strings.Trim(strings.TrimSpace(column.Tag), "'") + columnConfig.Tag = &column.Tag + return nil +} + func (v *Viewlet) applyExplicitCast(column *sqlparser.Column, funcArgs []string) (bool, error) { if column.Name == "" || column.Name == column.Namespace { if v.View.Schema == nil { @@ -124,7 +139,10 @@ func (v *Viewlet) applyExplicitCast(column *sqlparser.Column, funcArgs []string) column.Type = funcArgs[1] rType, err := types.LookupType(v.Resource.typeRegistry.Lookup, column.Type) if err != nil { - return false, fmt.Errorf("unknown column %v type: %s, %w", column.Name, column.Type, err) + // Keep unresolved custom cast as metadata only. This preserves declared type + // (e.g. *fee.Fee) for IR/yaml parity without forcing runtime type resolution. + // Built-in and resolvable types still set RawType. + return true, nil } column.RawType = rType return true, nil diff --git a/internal/translator/function/allowedorderbycolumn.go b/internal/translator/function/allowedorderbycolumn.go new file mode 100644 index 000000000..cfad0bbf7 --- /dev/null +++ b/internal/translator/function/allowedorderbycolumn.go @@ -0,0 +1,76 @@ +package function + +import ( + "fmt" + "strings" + + "github.com/viant/datly/view" + "github.com/viant/sqlparser" +) + +type allowedOrderByColumns struct{} + +func (c *allowedOrderByColumns) Apply(args []string, column *sqlparser.Column, resource *view.Resource, aView *view.View) error { + if aView.Selector == nil { + aView.Selector = &view.Config{} + } + values, err := convertArguments(c, args) + if err != nil { + return err + } + if aView.Selector.Constraints == nil { + aView.Selector.Constraints = &view.Constraints{} + } + aView.Selector.Constraints.OrderBy = true + if len(values) == 0 { + return fmt.Errorf("failed to discover expression in allowedOrderByColumns") + } + columns, ok := values[0].(string) + if !ok { 
+ return fmt.Errorf("invalid columns type: %T, expected: %T in allowedOrderByColumns", values[0], columns) + } + if len(aView.Selector.Constraints.OrderByColumn) == 0 { + aView.Selector.Constraints.OrderByColumn = map[string]string{} + } + for _, expression := range strings.Split(columns, ",") { + expression = strings.TrimSpace(expression) + + key := expression + value := expression + if strings.Contains(expression, ":") { + parts := strings.SplitN(expression, ":", 2) + key = parts[0] + value = parts[1] + } + + aView.Selector.Constraints.OrderByColumn[key] = value + lcKey := strings.ToLower(key) + if lcKey != key { + aView.Selector.Constraints.OrderByColumn[lcKey] = value + } + + if index := strings.Index(key, "."); index != -1 { + aView.Selector.Constraints.OrderByColumn[key[index+1:]] = value + } + } + return nil +} + +func (c *allowedOrderByColumns) Name() string { + return "allowed_order_by_columns" +} + +func (c *allowedOrderByColumns) Description() string { + return "set view.Selector.OrderBy and enables corresponding view.Selector.Constraints.OrderBy" +} + +func (c *allowedOrderByColumns) Arguments() []*Argument { + return []*Argument{ + { + Name: "allowedOrderByColumns", + Description: "query selector allowedOrderByColumns", + Required: true, + DataType: "string", + }, + } +} diff --git a/internal/translator/function/function.go b/internal/translator/function/function.go index e02545186..f150f773a 100644 --- a/internal/translator/function/function.go +++ b/internal/translator/function/function.go @@ -69,7 +69,7 @@ func convertArguments(signature Signature, args []string) ([]interface{}, error) result = append(result, v) default: - return nil, fmt.Errorf("unsupported %v data type", argument.Name, argument.DataType) + return nil, fmt.Errorf("unsupported %v data type: %s", argument.Name, argument.DataType) } } return result, nil diff --git a/internal/translator/function/init.go b/internal/translator/function/init.go index b67d258f0..b0141c3d8 100644 --- 
a/internal/translator/function/init.go +++ b/internal/translator/function/init.go @@ -5,6 +5,7 @@ func init() { _registry.Register(&cache{}) _registry.Register(&limit{}) _registry.Register(&orderBy{}) + _registry.Register(&allowedOrderByColumns{}) _registry.Register(&cardinality{}) _registry.Register(&allownulls{}) _registry.Register(&matchStrategy{}) diff --git a/internal/translator/function/orderby.go b/internal/translator/function/orderby.go index 645be6058..9dc22f5ca 100644 --- a/internal/translator/function/orderby.go +++ b/internal/translator/function/orderby.go @@ -20,6 +20,7 @@ func (c *orderBy) Apply(args []string, column *sqlparser.Column, resource *view. } aView.Selector.Constraints.OrderBy = true aView.Selector.OrderBy = values[0].(string) + return nil } diff --git a/internal/translator/output.go b/internal/translator/output.go index 6c562649d..368c21ff4 100644 --- a/internal/translator/output.go +++ b/internal/translator/output.go @@ -3,6 +3,9 @@ package translator import ( "context" "fmt" + "reflect" + "strings" + "github.com/viant/datly/internal/inference" "github.com/viant/datly/internal/setter" "github.com/viant/datly/repository/contract" @@ -21,8 +24,6 @@ import ( "github.com/viant/xdatly/handler/response/tabular/xml" "github.com/viant/xdatly/predicate" "github.com/viant/xreflect" - "reflect" - "strings" ) func (s *Service) updateOutputParameters(resource *Resource, rootViewlet *Viewlet) (err error) { @@ -54,6 +55,9 @@ func (s *Service) updateOutputParameters(resource *Resource, rootViewlet *Viewle outputParameters := s.ensureOutputParameters(resource, resource.OutputState) dataParameter := outputParameters.LookupByLocation(state.KindOutput, keys.ViewData) if dataParameter != nil { + if rootViewlet.View.Schema == nil { + return fmt.Errorf("view %s has no detected schema; ensure column discovery succeeded (connector: %s)", rootViewlet.Name, rootViewlet.GetConnector()) + } s.updateParameterWithComponentOutputType(dataParameter, rootViewlet) } @@ 
-452,6 +456,12 @@ func (s *Service) ensureOutputParameters(resource *Resource, outputState inferen } func (s *Service) updateParameterWithComponentOutputType(dataParameter *state.Parameter, rootViewlet *Viewlet) { + if rootViewlet == nil || rootViewlet.View == nil || rootViewlet.Resource == nil || rootViewlet.Resource.rule == nil { + return + } + if rootViewlet.View.Schema == nil { + rootViewlet.View.Schema = &state.Schema{} + } typeName := rootViewlet.View.Schema.Name if typeName == "" || typeName == "string" { typeName = view.DefaultTypeName(rootViewlet.Name) diff --git a/internal/translator/parser/declarations.go b/internal/translator/parser/declarations.go index ca11710de..69b017ec2 100644 --- a/internal/translator/parser/declarations.go +++ b/internal/translator/parser/declarations.go @@ -2,6 +2,10 @@ package parser import ( "fmt" + "reflect" + "strconv" + "strings" + "github.com/viant/datly/gateway/router/marshal" "github.com/viant/datly/internal/inference" "github.com/viant/datly/shared" @@ -12,9 +16,6 @@ import ( "github.com/viant/velty/ast/expr" "github.com/viant/velty/parser" "github.com/viant/xreflect" - "reflect" - "strconv" - "strings" ) type ( @@ -146,7 +147,11 @@ func (d *Declarations) parseExpression(cursor *parsly.Cursor, selector *expr.Sel declaration.Kind = segments[0] location := "" if len(segments) > 1 { - location = strings.Join(segments[1:], ".") + joiner := "." 
+ if declaration.Kind == string(state.KindComponent) { + joiner = "/" + } + location = strings.Join(segments[1:], joiner) } declaration.Location = &location declaration.InOutput = declaration.Kind == string(state.KindOutput) @@ -206,7 +211,7 @@ func (d *Declarations) tryParseTypeExpression(typeContent string, declaration *D dataType = strings.Replace(dataType, typeName, "interface{}", 1) } - if dataType != "" { + if dataType != "" && d.lookup != nil { if schema, _ := d.lookup(dataType); schema != nil { schema.Cardinality = declaration.Cardinality if rType := schema.Type(); rType != nil && schema.Cardinality == state.Many { @@ -322,6 +327,10 @@ func (s *Declarations) parseShorthands(declaration *Declaration, cursor *parsly. declaration.InOutput = true case "WithCache": declaration.Cache = strings.Trim(args[0], `"'`) + case "WithLimit": + limit, _ := strconv.Atoi(strings.Trim(args[0], `"'`)) + declaration.Limit = &limit + case "Cacheable": literal := strings.Trim(args[0], `"'`) value, _ := strconv.ParseBool(literal) diff --git a/internal/translator/parser/declarations_test.go b/internal/translator/parser/declarations_test.go index 58489bba0..5694acbdd 100644 --- a/internal/translator/parser/declarations_test.go +++ b/internal/translator/parser/declarations_test.go @@ -31,10 +31,71 @@ SELECT 1 FROM t WHERE ID IN($TeamIDs) Kind: state.KindQuery, Name: "tids", }, - Output: &state.Codec{Name: "AsInts"}, + Output: &state.Codec{Name: "AsInts", Args: []string{}}, Schema: &state.Schema{ Cardinality: state.One, + DataType: "string", }, + Required: &[]bool{false}[0], + }, + + ModificationSetting: inference.ModificationSetting{}, + SQL: "", + Hint: "", + }, + }, + }, + { + description: "Query string param with #define alias", + DSQL: ` +#define($_ = $TeamIDs(query/tids).WithCodec(AsInts)) +SELECT 1 FROM t WHERE ID IN($TeamIDs) +`, + expectedSQL: `SELECT 1 FROM t WHERE ID IN($TeamIDs)`, + expectedState: inference.State{ + &inference.Parameter{ + Explicit: true, + Parameter: 
state.Parameter{ + Name: "TeamIDs", + In: &state.Location{ + Kind: state.KindQuery, + Name: "tids", + }, + Output: &state.Codec{Name: "AsInts", Args: []string{}}, + Schema: &state.Schema{ + Cardinality: state.One, + DataType: "string", + }, + Required: &[]bool{false}[0], + }, + ModificationSetting: inference.ModificationSetting{}, + SQL: "", + Hint: "", + }, + }, + }, + { + description: "Query string param with #settings alias", + DSQL: ` +#settings($_ = $TeamIDs(query/tids).WithCodec(AsInts)) +SELECT 1 FROM t WHERE ID IN($TeamIDs) +`, + expectedSQL: `SELECT 1 FROM t WHERE ID IN($TeamIDs)`, + expectedState: inference.State{ + &inference.Parameter{ + Explicit: true, + Parameter: state.Parameter{ + Name: "TeamIDs", + In: &state.Location{ + Kind: state.KindQuery, + Name: "tids", + }, + Output: &state.Codec{Name: "AsInts", Args: []string{}}, + Schema: &state.Schema{ + Cardinality: state.One, + DataType: "string", + }, + Required: &[]bool{false}[0], }, ModificationSetting: inference.ModificationSetting{}, diff --git a/internal/translator/parser/lex.go b/internal/translator/parser/lex.go index 020aa0da0..1cab3a790 100644 --- a/internal/translator/parser/lex.go +++ b/internal/translator/parser/lex.go @@ -60,8 +60,8 @@ const ( var whitespaceMatcher = parsly.NewToken(whitespaceToken, "Whitespace", matcher.NewWhiteSpace()) var exprGroupMatcher = parsly.NewToken(exprGroupToken, "( .... 
)", matcher.NewBlock('(', ')', '\\')) -var setTerminatedMatcher = parsly.NewToken(setTerminatedToken, "#set", imatchers.NewStringTerminator("#set")) -var setMatcher = parsly.NewToken(setToken, "#set", matcher.NewFragments([]byte("#set"))) +var setTerminatedMatcher = parsly.NewToken(setTerminatedToken, "#set/#define/#settings", imatchers.NewAnyStringTerminator("#set", "#define", "#settings")) +var setMatcher = parsly.NewToken(setToken, "#set", matcher.NewFragments([]byte("#settings"), []byte("#define"), []byte("#set"))) var parameterDeclarationMatcher = parsly.NewToken(parameterDeclarationToken, "$_", matcher.NewSpacedSet([]string{"$_ = $"})) var commentMatcher = parsly.NewToken(commentToken, "/**/", matcher.NewSeqBlock("/*", "*/")) var typeMatcher = parsly.NewToken(typeToken, "", matcher.NewSeqBlock("<", ">")) @@ -70,7 +70,7 @@ var selectMatcher = parsly.NewToken(selectToken, "Applier call", imatchers.NewId var execStmtMatcher = parsly.NewToken(execStmtToken, "Exec statement", matcher.NewFragmentsFold([]byte("insert"), []byte("update"), []byte("delete"), []byte("call"), []byte("begin"))) var readStmtMatcher = parsly.NewToken(readStmtToken, "Select statement", matcher.NewFragmentsFold([]byte("select"))) -var exprMatcher = parsly.NewToken(exprToken, "Expression", matcher.NewFragments([]byte("#set"), []byte("#foreach"), []byte("#if"))) +var exprMatcher = parsly.NewToken(exprToken, "Expression", matcher.NewFragments([]byte("#settings"), []byte("#define"), []byte("#set"), []byte("#foreach"), []byte("#if"))) var anyMatcher = parsly.NewToken(anyToken, "Any", imatchers.NewAny()) var exprEndMatcher = parsly.NewToken(exprEndToken, "#end", matcher.NewFragmentsFold([]byte("#end"))) @@ -91,7 +91,7 @@ var ParenthesesBlockMatcher = parsly.NewToken(ParenthesesBlockToken, "Parenthese var endMatcher = parsly.NewToken(endToken, "End", matcher.NewFragment("#end")) var elseMatcher = parsly.NewToken(elseToken, "Else", matcher.NewFragment("#else")) var elseIfMatcher = 
parsly.NewToken(elseToken, "ElseIf", matcher.NewFragment("#elseif")) -var assignMatcher = parsly.NewToken(assignToken, "Set", matcher.NewFragment("#set")) +var assignMatcher = parsly.NewToken(assignToken, "Set", matcher.NewFragments([]byte("#settings"), []byte("#define"), []byte("#set"))) var forEachMatcher = parsly.NewToken(forEachToken, "ForEach", matcher.NewFragment("#foreach")) var ifMatcher = parsly.NewToken(ifToken, "If", matcher.NewFragment("#if")) diff --git a/internal/translator/parser/matchers/terminator.go b/internal/translator/parser/matchers/terminator.go index d94865f7e..fb133c6ee 100644 --- a/internal/translator/parser/matchers/terminator.go +++ b/internal/translator/parser/matchers/terminator.go @@ -9,6 +9,10 @@ type stringTerminatorMatcher struct { value []byte } +type anyStringTerminatorMatcher struct { + values [][]byte +} + func (t *stringTerminatorMatcher) Match(cursor *parsly.Cursor) (matched int) { if len(t.value) >= cursor.InputSize-cursor.Pos { return 0 @@ -25,6 +29,32 @@ func (t *stringTerminatorMatcher) Match(cursor *parsly.Cursor) (matched int) { return 0 } +func (t *anyStringTerminatorMatcher) Match(cursor *parsly.Cursor) (matched int) { + for i := cursor.Pos; i < cursor.InputSize; i++ { + for _, value := range t.values { + if len(value) == 0 || len(value) > cursor.InputSize-i { + continue + } + if bytes.Equal(cursor.Input[i:i+len(value)], value) { + return matched + } + } + matched++ + } + return 0 +} + func NewStringTerminator(by string) *stringTerminatorMatcher { return &stringTerminatorMatcher{value: []byte(by)} } + +func NewAnyStringTerminator(values ...string) *anyStringTerminatorMatcher { + ret := &anyStringTerminatorMatcher{} + for _, value := range values { + if value == "" { + continue + } + ret.values = append(ret.values, []byte(value)) + } + return ret +} diff --git a/internal/translator/parser/sanitizer_test.go b/internal/translator/parser/sanitizer_test.go new file mode 100644 index 000000000..38a9c8d1e --- /dev/null +++ 
b/internal/translator/parser/sanitizer_test.go @@ -0,0 +1,222 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/internal/inference" + "github.com/viant/datly/view/keywords" + "github.com/viant/velty/functions" +) + +func TestTemplate_Sanitize(t *testing.T) { + state := inference.State{} + tmpl, err := NewTemplate("#set($x = 1) SELECT * FROM t WHERE id = $x AND name = $Name", &state) + require.NoError(t, err) + actual := tmpl.Sanitize() + assert.Contains(t, actual, "#set($x = 1)") + assert.Contains(t, actual, "$criteria.AppendBinding($x)") + assert.Contains(t, actual, "$criteria.AppendBinding($Unsafe.Name)") +} + +func TestSanitize_SkipsFirstSetVariableOccurrence(t *testing.T) { + iter := newIterable(map[string]bool{"x": true}) + expr := &Expression{ + IsVariable: true, + OccurrenceIndex: 0, + Context: SetContext, + FullName: "$x", + Start: 0, + End: 2, + } + dst := []byte("$x") + actual, _ := sanitize(iter, expr, dst, 0, 0) + assert.Equal(t, "$x", string(actual)) +} + +func TestUnwrapBrackets(t *testing.T) { + raw, had := unwrapBrackets("${Foo}") + assert.Equal(t, "$Foo", raw) + assert.True(t, had) + + raw, had = unwrapBrackets("$Foo") + assert.Equal(t, "$Foo", raw) + assert.False(t, had) +} + +func TestSanitizeContent(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Start: 0, End: 10} + assert.Equal(t, "$A", sanitizeContent(iter, expr, "$A")) + + iter = newIterable(nil) + parent := &Expression{Start: 0, End: 13, FullName: "$Fn($A, $B)"} + argA := &Expression{Start: 4, End: 6, FullName: "$A", Holder: "A"} + argB := &Expression{Start: 8, End: 10, FullName: "$B", Holder: "B"} + next := &Expression{Start: 20, End: 22, FullName: "$C", Holder: "C"} + iter.expressions = Expressions{argA, argB, next} + actual := sanitizeContent(iter, parent, parent.FullName) + assert.Equal(t, "$Fn($criteria.AppendBinding($Unsafe.A), $criteria.AppendBinding($Unsafe.B))", 
actual) +} + +func TestSanitizeParameter(t *testing.T) { + t.Run("standalone fn entry is preserved", func(t *testing.T) { + name := "TestStandaloneSanitize" + keywords.Add(name, functions.NewEntry(nil, &keywords.StandaloneFn{})) + iter := newIterable(nil) + expr := &Expression{Holder: name, FullName: "$" + name + "(1)"} + assert.Equal(t, "$"+name+"(1)", sanitizeParameter(expr, "$"+name+"(1)", iter, nil, 0)) + }) + + t.Run("set marker prefix preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "Value", Prefix: keywords.SetMarkerKey} + assert.Equal(t, "$Value", sanitizeParameter(expr, "$Value", iter, nil, 0)) + }) + + t.Run("namespace metadata preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Any", + Entry: functions.NewEntry(nil, keywords.NewNamespace()), + } + assert.Equal(t, "$Any", sanitizeParameter(expr, "$Any", iter, nil, 0)) + }) + + t.Run("const parameter gets Unsafe prefix", func(t *testing.T) { + iter := newIterable(nil, inference.NewConstParameter("ConstX", 1)) + expr := &Expression{Holder: "ConstX"} + assert.Equal(t, "$Unsafe.ConstX", sanitizeParameter(expr, "$ConstX", iter, nil, 0)) + }) + + t.Run("func context with variable and Params prefix strips prefix", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X", Prefix: keywords.ParamsKey, Context: FuncContext} + assert.Equal(t, "$X", sanitizeParameter(expr, "$Unsafe.X", iter, nil, 0)) + }) + + t.Run("func context with non variable and empty prefix adds Unsafe", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Prefix: "", Context: FuncContext} + assert.Equal(t, "$Unsafe.X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("func context with variable and custom prefix keeps raw", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X", Prefix: keywords.AndPrefix, Context: 
ForEachContext} + assert.Equal(t, "$X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("func context with non variable and non empty prefix keeps raw", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Prefix: keywords.OrPrefix, Context: SetContext} + assert.Equal(t, "$X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("func context with expression entry preserves raw", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Context: IfContext, Entry: functions.NewEntry(nil, nil)} + assert.Equal(t, "$X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("append context variable with Params prefix strips prefix", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X", Prefix: keywords.ParamsKey} + assert.Equal(t, "$X", sanitizeParameter(expr, "$Unsafe.X", iter, nil, 0)) + }) + + t.Run("append context variable placeholder", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X"} + assert.Equal(t, "$criteria.AppendBinding($X)", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("append context params prefix preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Prefix: keywords.ParamsKey} + assert.Equal(t, "$Unsafe.X", sanitizeParameter(expr, "$Unsafe.X", iter, nil, 0)) + }) + + t.Run("context metadata unexpand raw preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Ctx", + Entry: functions.NewEntry(nil, keywords.NewContextMetadata("ctx", nil, true)), + } + assert.Equal(t, "$Ctx", sanitizeParameter(expr, "$Ctx", iter, nil, 0)) + }) + + t.Run("context metadata expandable becomes placeholder", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Ctx", + Entry: functions.NewEntry(nil, keywords.NewContextMetadata("ctx", nil, false)), + } + assert.Equal(t, 
"$criteria.AppendBinding($Ctx)", sanitizeParameter(expr, "$Ctx", iter, nil, 0)) + }) + + t.Run("non context metadata entry becomes placeholder", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Ctx", + Entry: functions.NewEntry(nil, struct{}{}), + } + assert.Equal(t, "$criteria.AppendBinding($Ctx)", sanitizeParameter(expr, "$Ctx", iter, nil, 0)) + }) + + t.Run("default path adds Unsafe and placeholder", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X"} + assert.Equal(t, "$criteria.AppendBinding($Unsafe.X)", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) +} + +func TestSanitizeAsPlaceholder(t *testing.T) { + assert.Equal(t, "$criteria.AppendBinding($X)", sanitizeAsPlaceholder("$X")) +} + +func TestSanitize_WithBracketsWrapping(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + FullName: "${X}", + Holder: "X", + Start: 0, + End: 4, + } + dst := []byte("${X}") + actual, _ := sanitize(iter, expr, dst, 0, 0) + assert.Equal(t, "${criteria.AppendBinding($Unsafe.X)}", string(actual)) +} + +func TestSanitize_NoChangePathAndCursorOffset(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + FullName: "$Unsafe.X", + Holder: "X", + Prefix: keywords.ParamsKey, + Start: 8, + End: 17, + } + dst := []byte("SELECT " + expr.FullName) + actual, offset := sanitize(iter, expr, dst, 0, 7) + assert.Equal(t, "SELECT $Unsafe.X", string(actual)) + assert.Equal(t, 0, offset) +} + +func newIterable(declared map[string]bool, params ...*inference.Parameter) *iterables { + if declared == nil { + declared = map[string]bool{} + } + state := inference.State{} + for _, param := range params { + if param != nil { + state.Append(param) + } + } + tmpl := &Template{ + Declared: declared, + State: &state, + } + return &iterables{expressionMatcher: &expressionMatcher{Template: tmpl}} +} diff --git a/internal/translator/parser/statement.go b/internal/translator/parser/statement.go index 3874f4c48..39633964f 
100644 --- a/internal/translator/parser/statement.go +++ b/internal/translator/parser/statement.go @@ -38,42 +38,49 @@ func (s Statements) DMLTables(rawSQL string) []string { var tables = make(map[string]bool) var result []string for _, statement := range s { + // Only consider exec statements for DML table extraction. + if !statement.IsExec { + continue + } SQL := rawSQL[statement.Start:statement.End] - usesService := strings.Contains(SQL, "$sql.") - lowerCasedDML := strings.ToLower(SQL) - quoted := "" - - if index := strings.Index(SQL, `"`); index != -1 { - quoted = SQL[index+1:] - if index = strings.Index(quoted, `"`); index != -1 { - quoted = quoted[:index] + // Handle service-based exec ($sql.Insert/$sql.Update) only when explicitly detected as service. + if statement.Kind == shared.ExecKindService { + quoted := "" + if index := strings.Index(SQL, `"`); index != -1 { + quoted = SQL[index+1:] + if index = strings.Index(quoted, `"`); index != -1 { + quoted = quoted[:index] + } } - } - if usesService && quoted != "" { - statement.Table = quoted - if _, ok := tables[statement.Table]; ok { + if quoted != "" { + statement.Table = quoted + if _, ok := tables[statement.Table]; ok { + continue + } + result = append(result, statement.Table) + tables[statement.Table] = true continue } - result = append(result, statement.Table) - tables[statement.Table] = true - continue } + + lowerCasedDML := strings.ToLower(SQL) + if strings.Contains(lowerCasedDML, "insert") { - if stmt, _ := sqlparser.ParseInsert(SQL); stmt != nil { + if stmt, _ := sqlparser.ParseInsert(SQL); stmt != nil && stmt.Target.X != nil { if table := sqlparser.Stringify(stmt.Target.X); table != "" { statement.Table = table } } } else if strings.Contains(lowerCasedDML, "update") { - if stmt, _ := sqlparser.ParseUpdate(SQL); stmt != nil { + if stmt, _ := sqlparser.ParseUpdate(SQL); stmt != nil && stmt.Target.X != nil { if table := sqlparser.Stringify(stmt.Target.X); table != "" { statement.Table = table } } } 
else if strings.Contains(lowerCasedDML, "delete") { - if stmt, _ := sqlparser.ParseDelete(SQL); stmt != nil { + if stmt, _ := sqlparser.ParseDelete(SQL); stmt != nil && stmt.Target.X != nil { if table := sqlparser.Stringify(stmt.Target.X); table != "" { statement.Table = table } diff --git a/internal/translator/resource.go b/internal/translator/resource.go index eb92605b8..bf63f33ef 100644 --- a/internal/translator/resource.go +++ b/internal/translator/resource.go @@ -3,6 +3,12 @@ package translator import ( "context" "fmt" + "net/http" + "path" + "reflect" + "regexp" + "strings" + "github.com/viant/afs" "github.com/viant/afs/url" "github.com/viant/datly/cmd/options" @@ -10,6 +16,7 @@ import ( "github.com/viant/datly/internal/msg" "github.com/viant/datly/internal/setter" tparser "github.com/viant/datly/internal/translator/parser" + "github.com/viant/datly/repository/content" expand "github.com/viant/datly/service/executor/expand" "github.com/viant/datly/shared" "github.com/viant/datly/utils/types" @@ -22,11 +29,40 @@ import ( "github.com/viant/toolbox" "github.com/viant/xreflect" "golang.org/x/mod/modfile" - "path" - "reflect" - "strings" ) +var ( + routeSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$route\s*\(([^)]*)\)\s*\)\s*$`) + packageLineExpr = regexp.MustCompile(`(?im)^\s*#package\s*\(\s*['"]([^'"]+)['"]\s*\)\s*$`) + hashImportLineExpr = regexp.MustCompile(`(?im)^\s*#import\s*\(([^)]*)\)\s*$`) + connectorSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$connector\s*\(([^)]*)\)\s*\)\s*$`) + handlerSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$handler\s*\(([^)]*)\)\s*\)\s*$`) + inputSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$input\s*\(([^)]*)\)\s*\)\s*$`) + outputSettingsLineExpr = 
regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$output\s*\(([^)]*)\)\s*\)\s*$`) + marshalSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$marshal\s*\(\s*['"]([^'"]+)['"]\s*,\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) + unmarshalSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$unmarshal\s*\(\s*['"]([^'"]+)['"]\s*,\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) + formatSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$format\s*\(\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) + dateFormatSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$date_format\s*\(\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) + caseFormatSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$case_format\s*\(\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) + quotedArgExpr = regexp.MustCompile(`['"]([^'"]*)['"]`) +) + +type routeSettingsDirective struct { + URI string + Methods []string + Package string + Connector string + HandlerType string + InputType string + OutputType string + JSONMarshalType string + JSONUnmarshalType string + XMLUnmarshalType string + Format string + DateFormat string + CaseFormat string +} + type ( Resource struct { Generated bool @@ -143,6 +179,7 @@ func (r *Resource) ensureRegistry() *xreflect.Types { } func (r *Resource) parseImports(ctx context.Context, dSQL *string) (err error) { + *dSQL = removeHashImportDirectives(*dSQL) if r.Rule.TypeSrc != nil { if err = r.loadImportTypes(ctx, r.Rule.TypeSrc); err != nil { return err @@ -352,6 +389,15 @@ func (r *Resource) buildParameterViews() { if parameter.Cache != "" { viewlet.View.Cache = &view.Cache{Reference: shared.Reference{Ref: parameter.Cache}} } + if parameter.Limit != nil { + if viewlet.View.Selector == nil { + viewlet.View.Selector = &view.Config{ + Constraints: &view.Constraints{Limit: true}, + } + } + viewlet.View.Selector.Limit = 
*parameter.Limit + viewlet.View.Selector.NoLimit = viewlet.View.Selector.Limit == 0 + } if viewlet.Connector == "" { viewlet.Connector = r.rootConnector } @@ -398,6 +444,50 @@ func (r *Resource) extractRuleSetting(dSQL *string) error { } *dSQL = (*dSQL)[index+2:] } + if directive, ok, err := parseSettingsDirectives(*dSQL); err != nil { + return err + } else if ok { + if directive.URI != "" { + r.Rule.URI = directive.URI + } + if len(directive.Methods) > 0 { + r.Rule.Method = strings.Join(directive.Methods, ",") + } + if directive.Package != "" { + r.Rule.Package = directive.Package + } + if directive.Connector != "" { + r.Rule.Connector = directive.Connector + } + if directive.HandlerType != "" { + r.Rule.Type = qualifyTypeWithPackage(directive.HandlerType, r.Rule.Package) + } + if directive.InputType != "" { + r.Rule.InputType = qualifyTypeWithPackage(directive.InputType, r.Rule.Package) + } + if directive.OutputType != "" { + r.Rule.OutputType = qualifyTypeWithPackage(directive.OutputType, r.Rule.Package) + } + if directive.JSONMarshalType != "" { + r.Rule.JSONMarshalType = qualifyTypeWithPackage(directive.JSONMarshalType, r.Rule.Package) + } + if directive.JSONUnmarshalType != "" { + r.Rule.JSONUnmarshalType = qualifyTypeWithPackage(directive.JSONUnmarshalType, r.Rule.Package) + } + if directive.XMLUnmarshalType != "" { + r.Rule.XMLUnmarshalType = qualifyTypeWithPackage(directive.XMLUnmarshalType, r.Rule.Package) + } + if directive.Format != "" { + r.Rule.DataFormat = directive.Format + } + if directive.DateFormat != "" { + r.Rule.Route.Content.DateFormat = directive.DateFormat + } + if directive.CaseFormat != "" { + r.Rule.Route.Output.CaseFormat = text.CaseFormat(directive.CaseFormat) + } + *dSQL = removeSettingsDirectives(*dSQL) + } r.Rule.applyShortHands() if r.Rule.Connector != "" { r.rule.Connector = r.Rule.Connector @@ -406,6 +496,275 @@ func (r *Resource) extractRuleSetting(dSQL *string) error { return nil } +func parseSettingsDirectives(dSQL string) 
(*routeSettingsDirective, bool, error) { + ret := &routeSettingsDirective{} + var found bool + matches := packageLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + if len(last) < 2 || strings.TrimSpace(last[1]) == "" { + return nil, false, fmt.Errorf("invalid #package directive") + } + ret.Package = strings.TrimSpace(last[1]) + } + matches = routeSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + if len(last) < 2 { + return nil, false, fmt.Errorf("invalid $route directive") + } + args := parseQuotedArgs(last[1]) + if len(args) == 0 { + return nil, false, fmt.Errorf("invalid $route directive: missing URI") + } + URI := strings.TrimSpace(args[0]) + if !strings.HasPrefix(URI, "/") { + return nil, false, fmt.Errorf("invalid $route directive: URI must start with /") + } + methods, err := normalizeRouteMethods(args[1:]) + if err != nil { + return nil, false, err + } + ret.URI = URI + ret.Methods = methods + } + + matches = connectorSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + value := parseSingleArg(last) + if value == "" { + return nil, false, fmt.Errorf("invalid $connector directive") + } + ret.Connector = value + } + + matches = handlerSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + value := parseSingleArg(last) + if value == "" { + return nil, false, fmt.Errorf("invalid $handler directive") + } + ret.HandlerType = value + } + + matches = inputSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + value := parseSingleArg(last) + if value == "" { + return nil, false, fmt.Errorf("invalid $input directive") + } + ret.InputType = value + } + + matches = outputSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if 
len(matches) > 0 { + found = true + last := matches[len(matches)-1] + value := parseSingleArg(last) + if value == "" { + return nil, false, fmt.Errorf("invalid $output directive") + } + ret.OutputType = value + } + + matches = marshalSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + if len(last) < 3 { + return nil, false, fmt.Errorf("invalid $marshal directive") + } + mimeType := strings.ToLower(strings.TrimSpace(last[1])) + if mimeType != content.JSONContentType { + return nil, false, fmt.Errorf("invalid $marshal directive: unsupported mime type %q", mimeType) + } + typeName := strings.TrimSpace(last[2]) + if typeName == "" { + return nil, false, fmt.Errorf("invalid $marshal directive: missing type") + } + ret.JSONMarshalType = typeName + } + + matches = unmarshalSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + for _, match := range matches { + if len(match) < 3 { + return nil, false, fmt.Errorf("invalid $unmarshal directive") + } + mimeType := strings.ToLower(strings.TrimSpace(match[1])) + typeName := strings.TrimSpace(match[2]) + if typeName == "" { + return nil, false, fmt.Errorf("invalid $unmarshal directive: missing type") + } + switch mimeType { + case content.JSONContentType: + ret.JSONUnmarshalType = typeName + case content.XMLContentType: + ret.XMLUnmarshalType = typeName + default: + return nil, false, fmt.Errorf("invalid $unmarshal directive: unsupported mime type %q", mimeType) + } + } + } + + matches = formatSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + if len(last) < 2 { + return nil, false, fmt.Errorf("invalid $format directive") + } + raw := strings.ToLower(strings.TrimSpace(last[1])) + switch raw { + case "tabular_json": + ret.Format = content.JSONDataFormatTabular + case content.JSONFormat, content.XMLFormat, content.CSVFormat, 
content.JSONDataFormatTabular: + ret.Format = raw + default: + return nil, false, fmt.Errorf("invalid $format directive: unsupported format %q", raw) + } + } + + matches = dateFormatSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + if len(last) < 2 || strings.TrimSpace(last[1]) == "" { + return nil, false, fmt.Errorf("invalid $date_format directive") + } + ret.DateFormat = strings.TrimSpace(last[1]) + } + + matches = caseFormatSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + if len(last) < 2 || strings.TrimSpace(last[1]) == "" { + return nil, false, fmt.Errorf("invalid $case_format directive") + } + caseFormat := strings.TrimSpace(last[1]) + if !text.NewCaseFormat(caseFormat).IsDefined() { + return nil, false, fmt.Errorf("invalid $case_format directive: unsupported case format %q", caseFormat) + } + ret.CaseFormat = caseFormat + } + return ret, found, nil +} + +func parseQuotedArgs(input string) []string { + matches := quotedArgExpr.FindAllStringSubmatch(input, -1) + result := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) < 2 { + continue + } + result = append(result, strings.TrimSpace(match[1])) + } + return result +} + +func parseSingleArg(match []string) string { + if len(match) < 2 { + return "" + } + value := strings.TrimSpace(match[1]) + value = strings.Trim(value, `"'`) + return strings.TrimSpace(value) +} + +func normalizeRouteMethods(input []string) ([]string, error) { + if len(input) == 0 { + return nil, nil + } + valid := map[string]bool{ + http.MethodGet: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodHead: true, + http.MethodOptions: true, + http.MethodTrace: true, + http.MethodConnect: true, + } + seen := map[string]bool{} + result := make([]string, 0, len(input)) + for _, item := range input { + 
method := strings.ToUpper(strings.TrimSpace(item)) + if method == "" { + return nil, fmt.Errorf("invalid $route directive: empty method") + } + if !valid[method] { + return nil, fmt.Errorf("invalid $route directive: unsupported method %q", method) + } + if seen[method] { + continue + } + seen[method] = true + result = append(result, method) + } + return result, nil +} + +func removeSettingsDirectives(dSQL string) string { + dSQL = packageLineExpr.ReplaceAllString(dSQL, "") + dSQL = routeSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = connectorSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = handlerSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = inputSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = outputSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = marshalSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = unmarshalSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = formatSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = dateFormatSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = caseFormatSettingsLineExpr.ReplaceAllString(dSQL, "") + return dSQL +} + +func removeHashImportDirectives(dSQL string) string { + return hashImportLineExpr.ReplaceAllString(dSQL, "") +} + +func qualifyTypeWithPackage(typeName, pkg string) string { + typeName = strings.TrimSpace(typeName) + pkg = strings.TrimSpace(pkg) + if typeName == "" || pkg == "" { + return typeName + } + + prefix := "" + base := typeName + for { + switch { + case strings.HasPrefix(base, "[]"): + prefix += "[]" + base = strings.TrimPrefix(base, "[]") + case strings.HasPrefix(base, "*"): + prefix += "*" + base = strings.TrimPrefix(base, "*") + default: + goto done + } + } +done: + if base == "" { + return typeName + } + if strings.Contains(base, ".") || strings.Contains(base, "/") || strings.Contains(base, "[") { + return typeName + } + return prefix + pkg + "." 
+ base +} + func (r *Resource) expandSQL(viewlet *Viewlet) (*sqlx.SQL, error) { types := viewlet.Resource.Resource.TypeRegistry() resourceState := viewlet.Resource.State @@ -464,9 +823,11 @@ func (r *Resource) expandSQL(viewlet *Viewlet) (*sqlx.SQL, error) { func (r *Resource) ensureViewParametersSchema(ctx context.Context, setType func(ctx context.Context, setType *Viewlet) error) error { viewParameters := r.State.FilterByKind(state.KindView) for _, viewParameter := range viewParameters { - if viewParameter.Schema != nil && viewParameter.Schema.Type() != nil { - continue - } + //WE DO NOT NEEDED IT + //if viewParameter.Schema != nil && viewParameter.Schema.Type() != nil { + // fmt.Printf("skipping view %v %v\n", viewParameter.Name, viewParameter.Schema) + // //continue + //} if viewParameter.In.Name == "" { //default root schema continue } @@ -721,7 +1082,7 @@ func (r *Resource) updatedObject(loadType func(typeName string) (reflect.Type, e schema := parameter.OutputSchema() wType := schema.Type() if wType == nil { - return fmt.Errorf("failed to get parameter auxiliary type: %s, %w", parameter.Name, schema.Name) + return fmt.Errorf("failed to get parameter auxiliary type: %s, %s", parameter.Name, schema.Name) } auxiliaryState := inference.State{} if err := r.extractState(loadType, wType, &auxiliaryState); err != nil { diff --git a/internal/translator/resource_settings_test.go b/internal/translator/resource_settings_test.go new file mode 100644 index 000000000..32de4d39b --- /dev/null +++ b/internal/translator/resource_settings_test.go @@ -0,0 +1,81 @@ +package translator + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/cmd/options" +) + +func TestResource_extractRuleSetting_RouteDirectiveOverridesHeader(t *testing.T) { + resource := &Resource{Rule: NewRule(), rule: &options.Rule{}} + dSQL := "/* {\"URI\":\"/v1/api/legacy\",\"Method\":\"GET\"} */\n" + + "#settings($_ = 
$route('/v1/api/orders', 'POST', 'PATCH'))\n" + + "#settings($_ = $marshal('application/json','pkg.OrderJSON'))\n" + + "#settings($_ = $unmarshal('application/json','pkg.OrderIn'))\n" + + "#settings($_ = $unmarshal('application/xml','pkg.OrderXMLIn'))\n" + + "#settings($_ = $format('tabular_json'))\n" + + "#settings($_ = $date_format('2006-01-02'))\n" + + "#settings($_ = $case_format('lc'))\n" + + "SELECT 1" + + err := resource.extractRuleSetting(&dSQL) + require.NoError(t, err) + assert.Equal(t, "/v1/api/orders", resource.Rule.URI) + assert.Equal(t, "POST,PATCH", resource.Rule.Method) + assert.Equal(t, "pkg.OrderJSON", resource.Rule.JSONMarshalType) + assert.Equal(t, "pkg.OrderIn", resource.Rule.JSONUnmarshalType) + assert.Equal(t, "pkg.OrderXMLIn", resource.Rule.XMLUnmarshalType) + assert.Equal(t, "tabular", resource.Rule.DataFormat) + assert.Equal(t, "2006-01-02", resource.Rule.Route.Content.DateFormat) + assert.Equal(t, "lc", string(resource.Rule.Route.Output.CaseFormat)) + assert.NotContains(t, dSQL, "$route(") + assert.NotContains(t, dSQL, "$marshal(") + assert.NotContains(t, dSQL, "$unmarshal(") + assert.NotContains(t, dSQL, "$format(") + assert.NotContains(t, dSQL, "$date_format(") + assert.NotContains(t, dSQL, "$case_format(") +} + +func TestResource_extractRuleSetting_InvalidRouteDirective(t *testing.T) { + resource := &Resource{Rule: NewRule(), rule: &options.Rule{}} + dSQL := "#settings($_ = $route('/v1/api/orders', 'GOT'))\nSELECT 1" + + err := resource.extractRuleSetting(&dSQL) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported method") +} + +func TestResource_extractRuleSetting_InvalidCaseFormatDirective(t *testing.T) { + resource := &Resource{Rule: NewRule(), rule: &options.Rule{}} + dSQL := "#settings($_ = $case_format('unknown'))\nSELECT 1" + + err := resource.extractRuleSetting(&dSQL) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported case format") +} + +func 
TestResource_extractRuleSetting_PackageQualifiesTypes(t *testing.T) { + resource := &Resource{Rule: NewRule(), rule: &options.Rule{}} + dSQL := "#package('github.vianttech.com/viant/handson/pkg/platform/acl/auth')\n" + + "#settings($_ = $handler('Handler'))\n" + + "#settings($_ = $input('Input'))\n" + + "#settings($_ = $output('Output'))\n" + + "#settings($_ = $marshal('application/json','JSONOut'))\n" + + "#settings($_ = $unmarshal('application/json','JSONIn'))\n" + + "SELECT 1" + + err := resource.extractRuleSetting(&dSQL) + require.NoError(t, err) + assert.Equal(t, "github.vianttech.com/viant/handson/pkg/platform/acl/auth", resource.Rule.Package) + assert.Equal(t, "github.vianttech.com/viant/handson/pkg/platform/acl/auth.Handler", resource.Rule.Type) + assert.Equal(t, "github.vianttech.com/viant/handson/pkg/platform/acl/auth.Input", resource.Rule.InputType) + assert.Equal(t, "github.vianttech.com/viant/handson/pkg/platform/acl/auth.Output", resource.Rule.OutputType) + assert.Equal(t, "github.vianttech.com/viant/handson/pkg/platform/acl/auth.JSONOut", resource.Rule.JSONMarshalType) + assert.Equal(t, "github.vianttech.com/viant/handson/pkg/platform/acl/auth.JSONIn", resource.Rule.JSONUnmarshalType) + assert.NotContains(t, dSQL, "#package(") + assert.NotContains(t, dSQL, "$handler(") + assert.NotContains(t, dSQL, "$input(") + assert.NotContains(t, dSQL, "$output(") +} diff --git a/internal/translator/rule.go b/internal/translator/rule.go index 98ba973a2..b2cd35c96 100644 --- a/internal/translator/rule.go +++ b/internal/translator/rule.go @@ -67,6 +67,7 @@ type ( IsGeneratation bool XMLUnmarshalType string `json:",omitempty"` JSONUnmarshalType string `json:",omitempty"` + JSONMarshalType string `json:",omitempty"` OutputParameter *inference.Parameter } @@ -132,6 +133,7 @@ func (r *Rule) DSQLSetting() interface{} { DocURLs []string `json:",omitempty"` Internal bool `json:",omitempty"` JSONUnmarshalType string `json:",omitempty"` + JSONMarshalType string 
`json:",omitempty"` Connector string `json:",omitempty"` contract.ModelContextProtocol contract.Meta @@ -148,6 +150,7 @@ func (r *Rule) DSQLSetting() interface{} { DocURLs: r.DocURLs, Internal: r.Internal, JSONUnmarshalType: r.JSONUnmarshalType, + JSONMarshalType: r.JSONMarshalType, Connector: r.Connector, ModelContextProtocol: r.ModelContextProtocol, Meta: r.Meta, @@ -190,7 +193,7 @@ func (r *Resource) initRule(ctx context.Context, fs afs.Service, dSQL *string) e rule := r.Rule rule.applyDefaults() if err := r.loadData(ctx, fs, rule.ConstURL, &rule.Const); err != nil { - r.messages.AddWarning(r.rule.RuleName(), "const", fmt.Sprintf("failed to load constant : %v %w", rule.ConstURL, err)) + r.messages.AddWarning(r.rule.RuleName(), "const", fmt.Sprintf("failed to load constant : %v %v", rule.ConstURL, err)) } r.State.AppendConst(rule.Const) return r.loadDocumentation(ctx, fs, rule) @@ -321,7 +324,9 @@ func (r *Rule) applyDefaults() { if r.XMLUnmarshalType != "" { r.Route.Content.Marshaller.XML.TypeName = r.XMLUnmarshalType } - if r.JSONUnmarshalType != "" { + if r.JSONMarshalType != "" { + r.Route.Content.Marshaller.JSON.TypeName = r.JSONMarshalType + } else if r.JSONUnmarshalType != "" { r.Route.Content.Marshaller.JSON.TypeName = r.JSONUnmarshalType } } diff --git a/internal/translator/service.go b/internal/translator/service.go index 66c9ccabb..f383b9a37 100644 --- a/internal/translator/service.go +++ b/internal/translator/service.go @@ -4,6 +4,12 @@ import ( "context" "database/sql" "fmt" + "net/http" + spath "path" + "reflect" + "strings" + "time" + "github.com/viant/afs" "github.com/viant/afs/file" "github.com/viant/afs/url" @@ -14,7 +20,7 @@ import ( "github.com/viant/datly/internal/plugin" "github.com/viant/datly/internal/setter" "github.com/viant/datly/internal/translator/parser" - signature "github.com/viant/datly/repository/contract/signature" + "github.com/viant/datly/repository/contract/signature" "github.com/viant/datly/repository/path" 
"github.com/viant/datly/service" "github.com/viant/datly/shared" @@ -27,11 +33,6 @@ import ( "github.com/viant/xreflect" "golang.org/x/mod/modfile" "gopkg.in/yaml.v3" - "net/http" - spath "path" - "reflect" - "strings" - "time" ) type Service struct { @@ -116,9 +117,12 @@ func (s *Service) discoverComponentContract(ctx context.Context, resource *Resou return nil, err } } - location.Name = strings.ReplaceAll(location.Name, "..", "[]") - location.Name = strings.ReplaceAll(location.Name, ".", "/") - method, URI := shared.ExtractPath(location.Name) + locationName := strings.TrimSpace(location.Name) + if !strings.Contains(locationName, "/") { + locationName = strings.ReplaceAll(locationName, "..", "[]") + locationName = strings.ReplaceAll(locationName, ".", "/") + } + method, URI := shared.ExtractPath(locationName) return s.signature.Signature(method, URI) } @@ -328,6 +332,15 @@ func (s *Service) persistRouterRule(ctx context.Context, resource *Resource, ser } route.Component.Meta = resource.Rule.Meta + if route.Component.Meta.DescriptionURI != "" { + URL := url.Join(baseRuleURL, route.Component.Meta.DescriptionURI) + description, err := s.fs.DownloadWithURL(ctx, URL) + if err != nil { + return fmt.Errorf("failed to download meta description: %v %w", URL, err) + } + route.Component.Meta.Description = string(description) + } + route.ModelContextProtocol = resource.Rule.ModelContextProtocol if route.Handler != nil { if route.Component.Output.Type.Schema == nil { @@ -361,7 +374,10 @@ func (s *Service) persistRouterRule(ctx context.Context, resource *Resource, ser if resource.Rule.XMLUnmarshalType != "" { route.Content.Marshaller.XML.TypeName = resource.Rule.XMLUnmarshalType } - if resource.Rule.JSONUnmarshalType != "" { + // JSON marshaller/unmarshaller customization: prefer MarshalType if provided, fallback to UnmarshalType. 
+ if resource.Rule.JSONMarshalType != "" { + route.Content.Marshaller.JSON.TypeName = resource.Rule.JSONMarshalType + } else if resource.Rule.JSONUnmarshalType != "" { route.Content.Marshaller.JSON.TypeName = resource.Rule.JSONUnmarshalType } route.Component.Output.DataFormat = resource.Rule.DataFormat @@ -435,7 +451,7 @@ func (s *Service) persistDocumentation(ctx context.Context, resource *Resource, } func extractTypeNameWithPackage(outputName string) (string, string) { - if index := strings.Index(outputName, "."); index != -1 { + if index := strings.LastIndex(outputName, "."); index != -1 { return outputName[:index], outputName[index+1:] } return outputName, "" @@ -449,6 +465,9 @@ func (s *Service) adjustView(viewlet *Viewlet, resource *Resource, mode view.Mod } if viewlet.TypeDefinition != nil { if viewlet.TypeDefinition.Cardinality == state.Many { + if viewlet.View.View.Schema == nil { + viewlet.View.View.Schema = &state.Schema{} + } viewlet.View.View.Schema.Cardinality = viewlet.TypeDefinition.Cardinality } viewlet.TypeDefinition.Cardinality = "" @@ -473,7 +492,12 @@ func (s *Service) adjustView(viewlet *Viewlet, resource *Resource, mode view.Mod if len(resource.Declarations.QuerySelectors) > 0 { for key, state := range resource.Declarations.QuerySelectors { - return fmt.Errorf("unknown query selector view %v, %v", key, state[0].Name) + switch strings.ToLower(state[0].Name) { + case "limit", "page", "offset", "fields", "orderby", "criteria": + default: + return fmt.Errorf("unknown query selector view %v, %v", key, state[0].In.Name) + + } } } @@ -561,7 +585,8 @@ func (s *Service) buildQueryViewletType(ctx context.Context, viewlet *Viewlet) e func (s *Service) buildViewletType(ctx context.Context, db *sql.DB, viewlet *Viewlet) (err error) { shared.EnsureArgs(viewlet.Expanded.Query, &viewlet.Expanded.Args) - if viewlet.Spec, err = inference.NewSpec(ctx, db, &s.Repository.Messages, viewlet.Table.Name, viewlet.ColumnConfig, viewlet.Expanded.Query, 
viewlet.Expanded.Args...); err != nil { + viewlet.Spec, err = inference.NewSpec(ctx, db, &s.Repository.Messages, viewlet.Table.Name, viewlet.ColumnConfig, viewlet.Expanded.Query, viewlet.Expanded.Args...) + if err != nil { return fmt.Errorf("failed to create spec for %v, %w", viewlet.Name, err) } diff --git a/internal/translator/view.go b/internal/translator/view.go index ab94c500e..fc0096ca8 100644 --- a/internal/translator/view.go +++ b/internal/translator/view.go @@ -2,15 +2,17 @@ package translator import ( "fmt" + "github.com/viant/datly/internal/asset" "github.com/viant/datly/internal/inference" "github.com/viant/datly/internal/setter" "github.com/viant/datly/internal/translator/parser" + "path" + "github.com/viant/datly/view" "github.com/viant/datly/view/state" "github.com/viant/tagly/format/text" - "path" ) type ( @@ -212,7 +214,8 @@ func (v *View) buildSelector(namespace *Viewlet, rule *Rule) { selector.PageParameter = ¶meter.Parameter selector.Constraints.Page = &enabled } - delete(namespace.Resource.Declarations.QuerySelectors, namespace.Name) + + //delete(namespace.Resource.Declarations.QuerySelectors, namespace.Name) } } diff --git a/internal/translator/viewlets.go b/internal/translator/viewlets.go index eb1d24b7b..ddeb349aa 100644 --- a/internal/translator/viewlets.go +++ b/internal/translator/viewlets.go @@ -76,7 +76,7 @@ func (n *Viewlets) Init(ctx context.Context, aQuery *query.Select, resource *Res if err := n.Each(func(viewlet *Viewlet) error { n.ensureConnector(viewlet, rootConnector) if err := initFn(ctx, viewlet); err != nil { - return fmt.Errorf("failed to init viewlet: %ns, %w", viewlet.Name, err) + return fmt.Errorf("failed to init viewlet: %s, %w", viewlet.Name, err) } return nil }); err != nil { @@ -138,6 +138,9 @@ func (n *Viewlets) addRelations(query *query.Select) error { parentNs := inference.ParentAlias(join) parentViewlet := n.Lookup(parentNs) + if parentViewlet == nil { + return fmt.Errorf("parent viewlet %v doesn't exist", 
parentNs) + } relation.Spec.Parent = parentViewlet.Spec cardinality := state.Many if inference.IsToOne(join) || relation.OutputSettings.IsToOne() { diff --git a/logger/adapter.go b/logger/adapter.go index d3777ff0c..e060cdd62 100644 --- a/logger/adapter.go +++ b/logger/adapter.go @@ -88,7 +88,7 @@ func (l *Adapter) Inherit(adapter *Adapter) { func (l *Adapter) LogDatabaseErr(SQL string, err error, args ...interface{}) { SQL = shared.ExpandSQL(SQL, args) - fmt.Printf(fmt.Sprintf("error occured while executing SQL: %v, SQL: %v, params: %v\n", err, strings.ReplaceAll(SQL, "\n", "\\n"), args)) + fmt.Printf("error occured while executing SQL: %v, SQL: %v, params: %v\n", err, strings.ReplaceAll(SQL, "\n", "\\n"), args) } func NewLogger(name string, logger Logger) *Adapter { diff --git a/mcp/server.go b/mcp/server.go index 2c9cd306e..b3e1c3e45 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -3,6 +3,7 @@ package mcp import ( "context" "fmt" + "github.com/viant/afs" "github.com/viant/afs/http" "github.com/viant/afs/url" @@ -15,18 +16,20 @@ import ( "github.com/viant/mcp/client/auth/transport" authserver "github.com/viant/mcp/server/auth" - serverproto "github.com/viant/mcp-protocol/server" - "github.com/viant/scy/auth/flow" "os" "path" + serverproto "github.com/viant/mcp-protocol/server" + "github.com/viant/scy/auth/flow" + + "reflect" + "strconv" + "strings" + "github.com/viant/mcp/server" "github.com/viant/scy" "github.com/viant/scy/cred" "golang.org/x/oauth2" - "reflect" - "strconv" - "strings" ) type Server struct { @@ -40,7 +43,7 @@ func (s *Server) init() error { var newImplementer = extension.New(s.registry) var options = []server.Option{ server.WithNewHandler(newImplementer), - server.WithImplementation(schema.Implementation{"Datly", "0.1"}), + server.WithImplementation(schema.Implementation{Name: "Datly", Version: "0.1"}), } issuerURL := s.config.IssuerURL var oauth2Config *oauth2.Config @@ -54,6 +57,7 @@ func (s *Server) init() error { } if issuerURL == "" && 
oauth2Config != nil { issuerURL, _ = url.Base(oauth2Config.Endpoint.AuthURL, http.SecureScheme) + s.config.IssuerURL = issuerURL } } authPolicy := &authorization.Policy{ diff --git a/repository/component.go b/repository/component.go index 31faed6cf..179ff7a9a 100644 --- a/repository/component.go +++ b/repository/component.go @@ -4,6 +4,10 @@ import ( "context" "embed" "fmt" + "net/http" + "reflect" + "strings" + "github.com/francoispqt/gojay" "github.com/viant/afs" "github.com/viant/datly/gateway/router/marshal" @@ -11,9 +15,10 @@ import ( "github.com/viant/datly/gateway/router/marshal/json" "github.com/viant/datly/internal/setter" "github.com/viant/datly/repository/async" - "github.com/viant/datly/repository/content" + content "github.com/viant/datly/repository/content" "github.com/viant/datly/repository/contract" "github.com/viant/datly/repository/handler" + "github.com/viant/datly/repository/shape/typectx" "github.com/viant/datly/repository/version" "github.com/viant/datly/service" "github.com/viant/datly/shared" @@ -29,9 +34,6 @@ import ( xhandler "github.com/viant/xdatly/handler" hstate "github.com/viant/xdatly/handler/state" "github.com/viant/xreflect" - "net/http" - "reflect" - "strings" ) // Component represents abstract API view/handler based component @@ -46,6 +48,7 @@ type ( View *view.View `json:",omitempty"` NamespacedView *view.NamespacedView Handler *handler.Handler `json:",omitempty"` + TypeContext *typectx.Context `json:",omitempty" yaml:",omitempty"` indexedView view.NamedViews SourceURL string @@ -169,6 +172,109 @@ func (c *Component) initView(ctx context.Context, resource *view.Resource) error if err := c.View.Init(ctx, resource); err != nil { return err } + // For read components (GET), expose and enable offset/limit/fields/page/orderBy for each namespaced view. 
+ if strings.EqualFold(c.Path.Method, http.MethodGet) { + // Helper to enable limit/offset for a view with namespace prefix (if any) + ensureSelectors := func(v *view.View, nsPrefix string) { + if v == nil { + return + } + if v.Selector == nil { + v.Selector = &view.Config{} + } + if v.Selector.Constraints == nil { + v.Selector.Constraints = &view.Constraints{} + } + // Enable constraints + v.Selector.Constraints.Limit = true + v.Selector.Constraints.Offset = true + v.Selector.Constraints.Projection = true + v.Selector.Constraints.OrderBy = true + + // Limit param + if v.Selector.LimitParameter == nil { + p := *view.QueryStateParameters.LimitParameter + p.Description = view.Description(view.LimitQuery, v.Name) + if nsPrefix != "" { + p.In = state.NewQueryLocation(nsPrefix + view.LimitQuery) + } + v.Selector.LimitParameter = &p + } else if v.Selector.LimitParameter.Description == "" { + v.Selector.LimitParameter.Description = view.Description(view.LimitQuery, v.Name) + } + + // Offset param + if v.Selector.OffsetParameter == nil { + p := *view.QueryStateParameters.OffsetParameter + p.Description = view.Description(view.OffsetQuery, v.Name) + if nsPrefix != "" { + p.In = state.NewQueryLocation(nsPrefix + view.OffsetQuery) + } + v.Selector.OffsetParameter = &p + } else if v.Selector.OffsetParameter.Description == "" { + v.Selector.OffsetParameter.Description = view.Description(view.OffsetQuery, v.Name) + } + + // Fields param (controls which fields are included) + if v.Selector.FieldsParameter == nil { + p := *view.QueryStateParameters.FieldsParameter + p.Description = view.Description(view.FieldsQuery, v.Name) + if nsPrefix != "" { + p.In = state.NewQueryLocation(nsPrefix + view.FieldsQuery) + } + v.Selector.FieldsParameter = &p + } else if v.Selector.FieldsParameter.Description == "" { + v.Selector.FieldsParameter.Description = view.Description(view.FieldsQuery, v.Name) + } + + // Page param (paging interface on top of limit/offset) + if v.Selector.PageParameter == 
nil { + p := *view.QueryStateParameters.PageParameter + p.Description = view.Description(view.PageQuery, v.Name) + if nsPrefix != "" { + p.In = state.NewQueryLocation(nsPrefix + view.PageQuery) + } + v.Selector.PageParameter = &p + } else if v.Selector.PageParameter.Description == "" { + v.Selector.PageParameter.Description = view.Description(view.PageQuery, v.Name) + } + + // OrderBy param + if v.Selector.OrderByParameter == nil { + p := *view.QueryStateParameters.OrderByParameter + p.Description = view.Description(view.OrderByQuery, v.Name) + if nsPrefix != "" { + p.In = state.NewQueryLocation(nsPrefix + view.OrderByQuery) + } + v.Selector.OrderByParameter = &p + } else if v.Selector.OrderByParameter.Description == "" { + v.Selector.OrderByParameter.Description = view.Description(view.OrderByQuery, v.Name) + } + } + + // Root view + nsPrefix := "" + if c.View.Selector != nil && c.View.Selector.Namespace != "" { + nsPrefix = c.View.Selector.Namespace + } + ensureSelectors(c.View, nsPrefix) + + // All related views via NamespacedView + if c.NamespacedView != nil { + for _, nsView := range c.NamespacedView.Views { + v := nsView.View + // Determine ns prefix from NamespacedView (prefer non-empty namespace if present) + pfx := "" + for _, ns := range nsView.Namespaces { + if ns != "" { + pfx = ns + break + } + } + ensureSelectors(v, pfx) + } + } + } holder := "" if c.Contract.Output.Type.Parameters != nil { if rootHolder := c.Contract.Output.Type.Parameters.LookupByLocation(state.KindOutput, "view"); rootHolder != nil { @@ -260,27 +366,143 @@ func (c *Component) IOConfig() *config.IOConfig { } func (c *Component) UnmarshalFunc(request *http.Request) shared.Unmarshal { - contentType := request.Header.Get(content.HeaderContentType) - setter.SetStringIfEmpty(&contentType, request.Header.Get(strings.ToLower(content.HeaderContentType))) + // Delegate to options-based variant for symmetry and centralization. 
+ return c.UnmarshalFor(WithUnmarshalRequest(request)) +} + +// UnmarshalOption configures unmarshal behavior for Component.UnmarshalFor. +type UnmarshalOption func(*unmarshalOptions) + +type unmarshalOptions struct { + request *http.Request + contentType string + interceptors json.UnmarshalerInterceptors +} + +// WithUnmarshalRequest supplies an http request for content-type detection and transforms. +func WithUnmarshalRequest(r *http.Request) UnmarshalOption { + return func(o *unmarshalOptions) { o.request = r } +} + +// WithContentType overrides the detected content type. +func WithContentType(ct string) UnmarshalOption { + return func(o *unmarshalOptions) { o.contentType = ct } +} + +// WithUnmarshalInterceptors adds/overrides JSON path interceptors. +func WithUnmarshalInterceptors(m json.UnmarshalerInterceptors) UnmarshalOption { + return func(o *unmarshalOptions) { + if o.interceptors == nil { + o.interceptors = json.UnmarshalerInterceptors{} + } + for k, v := range m { + o.interceptors[k] = v + } + } +} + +// UnmarshalFor returns a request-scoped unmarshal function applying content-type detection and transforms. 
+func (c *Component) UnmarshalFor(opts ...UnmarshalOption) shared.Unmarshal { + options := &unmarshalOptions{} + for _, opt := range opts { + if opt != nil { + opt(options) + } + } + + // Resolve content type if request present + contentType := options.contentType + if contentType == "" && options.request != nil { + contentType = options.request.Header.Get(content.HeaderContentType) + setter.SetStringIfEmpty(&contentType, options.request.Header.Get(strings.ToLower(content.HeaderContentType))) + } + switch contentType { case content.XMLContentType: return c.Content.Marshaller.XML.Unmarshal case content.CSVContentType: return c.Content.CSV.Unmarshal - default: - switch c.Output.DataFormat { - case content.XMLFormat: - return c.Content.Marshaller.XML.Unmarshal + } + // Fallback to data format preference when no content type or not matched + if c.Output.DataFormat == content.XMLFormat { + return c.Content.Marshaller.XML.Unmarshal + } + + // Build JSON path interceptors from component transforms and any user-provided ones + interceptors := options.interceptors + if interceptors == nil { + interceptors = json.UnmarshalerInterceptors{} + } + if options.request != nil { + for _, transform := range c.UnmarshallerInterceptors() { + interceptors[transform.Path] = c.transformFn(options.request, transform) + } + } + + req := options.request // capture for closure + return func(data []byte, dest interface{}) error { + if len(interceptors) > 0 || req != nil { + return c.Content.Marshaller.JSON.JsonMarshaller.Unmarshal(data, dest, interceptors, req) + } + return c.Content.Marshaller.JSON.JsonMarshaller.Unmarshal(data, dest) + } +} + +// MarshalOption configures marshal behavior for Component.MarshalFunc. +type MarshalOption func(*marshalOptions) + +type marshalOptions struct { + request *http.Request + format string + field string + filters []*json.FilterEntry +} + +// WithRequest supplies an http request for deriving format and state-based exclusions. 
+func WithRequest(r *http.Request) MarshalOption { return func(o *marshalOptions) { o.request = r } } + +// WithFormat overrides the output format (e.g. content.JSONFormat, content.CSVFormat, etc.). +func WithFormat(format string) MarshalOption { return func(o *marshalOptions) { o.format = format } } + +// WithField overrides the field used by tabular JSON embedding. +func WithField(field string) MarshalOption { return func(o *marshalOptions) { o.field = field } } + +// WithFilters sets explicit JSON field filters (exclusion-based projection). +func WithFilters(filters []*json.FilterEntry) MarshalOption { + return func(o *marshalOptions) { o.filters = filters } +} + +// MarshalFunc returns a request-scoped marshaller closure applying options like format and exclusions. +// If no format is specified, it defaults to JSON for non-reader services and derives from request for readers. +func (c *Component) MarshalFunc(opts ...MarshalOption) shared.Marshal { + options := &marshalOptions{} + for _, opt := range opts { + if opt != nil { + opt(options) } } - jsonPathInterceptor := json.UnmarshalerInterceptors{} - unmarshallerInterceptors := c.UnmarshallerInterceptors() - for i := range unmarshallerInterceptors { - transform := unmarshallerInterceptors[i] - jsonPathInterceptor[transform.Path] = c.transformFn(request, transform) + + // Resolve format + format := options.format + if format == "" { + if options.request != nil && c.Service == service.TypeReader { + format = c.Output.Format(options.request.URL.Query()) + } else { + format = content.JSONFormat + } + } + + // Resolve field (used for tabular JSON embedding) + field := options.field + if field == "" { + field = c.Output.Field() } - return func(bytes []byte, i interface{}) error { - return c.Content.Marshaller.JSON.JsonMarshaller.Unmarshal(bytes, i, jsonPathInterceptor, request) + + // Resolve filters (explicit only) + filters := options.filters + + return func(src interface{}) ([]byte, error) { + return 
c.Content.Marshal(format, field, src, filters) } } @@ -424,6 +646,9 @@ func WithContract(inputType, outputType reflect.Type, embedFs *embed.FS, viewOpt aCache := &view.Cache{Reference: shared.Reference{Ref: aView.Cache}} viewOptions = append(viewOptions, view.WithCache(aCache)) } + if aView.Limit != nil { + viewOptions = append(viewOptions, view.WithLimit(aView.Limit)) + } if aTag.View.PublishParent { viewOptions = append(viewOptions, view.WithViewPublishParent(aTag.View.PublishParent)) @@ -441,6 +666,10 @@ func WithContract(inputType, outputType reflect.Type, embedFs *embed.FS, viewOpt if aTag.View.Batch != 0 { viewOptions = append(viewOptions, view.WithBatchSize(aTag.View.Batch)) } + if aTag.View.Limit != nil { + viewOptions = append(viewOptions, view.WithLimit(aTag.View.Limit)) + } + if aTag.View.RelationalConcurrency != 0 { viewOptions = append(viewOptions, view.WithRelationalConcurrency(aTag.View.RelationalConcurrency)) } diff --git a/repository/components.go b/repository/components.go index c803bcd3f..a431095a3 100644 --- a/repository/components.go +++ b/repository/components.go @@ -13,6 +13,13 @@ import ( "github.com/viant/datly/internal/inference" "github.com/viant/datly/internal/translator/parser" "github.com/viant/datly/repository/codegen" + "github.com/viant/datly/repository/shape" + shapecolumn "github.com/viant/datly/repository/shape/column" + dqlparse "github.com/viant/datly/repository/shape/dql/parse" + shapeLoad "github.com/viant/datly/repository/shape/load" + shapePlan "github.com/viant/datly/repository/shape/plan" + shapeScan "github.com/viant/datly/repository/shape/scan" + "github.com/viant/datly/repository/shape/typectx" "github.com/viant/datly/repository/version" "github.com/viant/datly/utils/types" "github.com/viant/datly/view" @@ -24,6 +31,7 @@ import ( "gopkg.in/yaml.v3" "path" "reflect" + "strings" ) type Components struct { @@ -61,6 +69,9 @@ func (c *Components) Init(ctx context.Context) error { options = append(options, 
&view.Metrics{Method: c.Components[0].Method, Service: c.options.metrics}) } for _, component := range c.Components { + if c.options != nil && c.options.legacyTypeContext { + component.TypeContext = resolveComponentTypeContext(component) + } if len(component.with) > 0 { c.With = append(c.With, component.with...) } @@ -80,6 +91,9 @@ func (c *Components) Init(ctx context.Context) error { } c.ensureNamedViewType(ctx, embedFs, aComponent) + if err = c.mergeShapeViews(ctx, aComponent); err != nil { + return err + } if err = c.Resource.Init(ctx, options...); err != nil { return err @@ -106,6 +120,62 @@ func (c *Components) Init(ctx context.Context) error { return nil } +func (c *Components) mergeShapeViews(ctx context.Context, aComponent *Component) error { + if c.options == nil || !c.options.shapePipeline || aComponent == nil || aComponent.Output.Type.Schema == nil { + return nil + } + rType := c.ReflectType(aComponent.Output.Type.Schema) + if rType == nil { + return nil + } + engine := shape.New( + shape.WithScanner(shapeScan.New()), + shape.WithPlanner(shapePlan.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName(aComponent.Path.URI), + ) + source := zeroValue(rType) + if source == nil { + return nil + } + artifacts, err := engine.LoadViews(ctx, source) + if err != nil { + return fmt.Errorf("failed to load shape views for %s: %w", aComponent.Path.URI, err) + } + if artifacts == nil || artifacts.Resource == nil { + return nil + } + if c.Resource.FSEmbedder == nil && artifacts.Resource.FSEmbedder != nil { + c.Resource.FSEmbedder = artifacts.Resource.FSEmbedder + } + existing := c.Resource.Views.Index() + columnDetector := shapecolumn.New() + for _, candidate := range artifacts.Views { + if candidate == nil { + continue + } + if _, err = existing.Lookup(candidate.Name); err == nil { + continue + } + if candidate.Columns, err = columnDetector.Resolve(ctx, c.Resource, candidate); err != nil { + return fmt.Errorf("failed to resolve shape columns for %s: %w", 
candidate.Name, err) + } + c.Resource.Views = append(c.Resource.Views, candidate) + existing.Register(candidate) + } + return nil +} + +func zeroValue(rType reflect.Type) interface{} { + if rType == nil { + return nil + } + if rType.Kind() == reflect.Ptr { + return reflect.New(rType.Elem()).Interface() + } + return reflect.New(rType).Interface() +} + func (c *Components) ensureNamedViewType(ctx context.Context, embedFs *embed.FS, aComponent *Component) { inCodeGeneration := codegen.IsGeneratorContext(ctx) if rType := c.ReflectType(c.Components[0].Output.Type.Schema); rType != nil && !inCodeGeneration { @@ -236,6 +306,9 @@ func (c *Components) updateIOTypeDependencies(ctx context.Context, ioType *state aView = baseView } } + if aView.Schema == nil { + aView.Schema = parameterViewSchema(parameter) + } aView.Schema.SetType(parameter.Schema.Type()) } } @@ -371,7 +444,7 @@ func LoadComponents(ctx context.Context, URL string, opts ...Option) (*Component } } } - components, err := unmarshalComponent(data) + components, err := unmarshalComponent(data, options.legacyTypeContext) if err != nil { return nil, err } @@ -393,17 +466,47 @@ func LoadComponents(ctx context.Context, URL string, opts ...Option) (*Component return components, nil } -func unmarshalComponent(data []byte) (*Components, error) { +// LoadComponentsFromMap loads components directly from in-memory route/resource model. +// The input map is expected to follow the same shape as route YAML after unmarshalling. 
+func LoadComponentsFromMap(ctx context.Context, model map[string]any, opts ...Option) (*Components, error) { + if len(model) == 0 { + return nil, fmt.Errorf("components model was empty") + } + options := NewOptions(opts) + components, err := unmarshalComponentMap(model, options.legacyTypeContext) + if err != nil { + return nil, err + } + components.options = options + components.resources = options.resources + if components.Resource == nil { + return nil, fmt.Errorf("resources were empty") + } + if err = components.mergeResources(ctx); err != nil { + return nil, err + } + components.Resource.SetTypes(options.extensions.Types) + return components, nil +} + +func unmarshalComponent(data []byte, enableLegacyTypeContext bool) (*Components, error) { aMap := map[string]interface{}{} if err := yaml.Unmarshal(data, &aMap); err != nil { return nil, err } + return unmarshalComponentMap(aMap, enableLegacyTypeContext) +} + +func unmarshalComponentMap(aMap map[string]any, enableLegacyTypeContext bool) (*Components, error) { ensureComponents(aMap) components := &Components{} err := toolbox.DefaultConverter.AssignConverted(components, aMap) if err != nil { return nil, err } + if enableLegacyTypeContext { + applyLegacyTypeContext(aMap, components) + } return components, err } @@ -412,3 +515,141 @@ func ensureComponents(aMap map[string]interface{}) { aMap["Components"] = aMap["Routes"] } } + +func applyLegacyTypeContext(source map[string]any, components *Components) { + if len(components.Components) == 0 { + return + } + defaultTypeContext := asTypeContext(source["TypeContext"]) + items := asAnySlice(source["Components"]) + for i, component := range components.Components { + if component == nil { + continue + } + if component.TypeContext != nil { + continue + } + var resolved *typectx.Context + if i < len(items) { + if itemMap := asStringMap(items[i]); itemMap != nil { + resolved = asTypeContext(itemMap["TypeContext"]) + } + } + if resolved == nil { + resolved = defaultTypeContext 
+ } + if resolved != nil { + component.TypeContext = cloneTypeContext(resolved) + } + } +} + +func asTypeContext(raw any) *typectx.Context { + mapped := asStringMap(raw) + if mapped == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: asString(mapped["DefaultPackage"]), + } + for _, item := range asAnySlice(mapped["Imports"]) { + itemMap := asStringMap(item) + if itemMap == nil { + continue + } + pkg := asString(itemMap["Package"]) + if pkg == "" { + continue + } + ret.Imports = append(ret.Imports, typectx.Import{ + Alias: asString(itemMap["Alias"]), + Package: pkg, + }) + } + if ret.DefaultPackage == "" && len(ret.Imports) == 0 { + return nil + } + return ret +} + +func resolveComponentTypeContext(component *Component) *typectx.Context { + if component == nil { + return nil + } + if normalized := normalizeTypeContext(component.TypeContext); normalized != nil { + return normalized + } + if component.View == nil || component.View.Template == nil { + return nil + } + source := strings.TrimSpace(component.View.Template.Source) + if source == "" { + return nil + } + parsed, err := dqlparse.New().Parse(source) + if err != nil || parsed == nil { + return nil + } + return normalizeTypeContext(parsed.TypeContext) +} + +func normalizeTypeContext(input *typectx.Context) *typectx.Context { + if input == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + } + for _, item := range input.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + ret.Imports = append(ret.Imports, typectx.Import{ + Alias: strings.TrimSpace(item.Alias), + Package: pkg, + }) + } + if ret.DefaultPackage == "" && len(ret.Imports) == 0 { + return nil + } + return ret +} + +func cloneTypeContext(input *typectx.Context) *typectx.Context { + return normalizeTypeContext(input) +} + +func asAnySlice(raw any) []any { + switch actual := raw.(type) { + case []any: + return actual + default: + return nil + } 
+} + +func asStringMap(raw any) map[string]any { + switch actual := raw.(type) { + case map[string]any: + return actual + case map[interface{}]interface{}: + result := make(map[string]any, len(actual)) + for k, v := range actual { + result[fmt.Sprint(k)] = v + } + return result + default: + return nil + } +} + +func asString(raw any) string { + if raw == nil { + return "" + } + if value, ok := raw.(string); ok { + return value + } + return fmt.Sprint(raw) +} diff --git a/repository/components_shape_test.go b/repository/components_shape_test.go new file mode 100644 index 000000000..d6ddfd00b --- /dev/null +++ b/repository/components_shape_test.go @@ -0,0 +1,62 @@ +package repository + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type shapeTestRow struct { + ID int +} + +type shapeTestOutput struct { + Rows []shapeTestRow `view:"rows,table=REPORT" sql:"SELECT ID FROM REPORT"` +} + +func TestComponents_mergeShapeViews_Enabled(t *testing.T) { + resource := view.EmptyResource() + components := &Components{ + Resource: resource, + options: &Options{shapePipeline: true}, + } + + component := &Component{ + Path: contract.Path{URI: "/v1/api/report", Method: "GET"}, + Contract: contract.Contract{ + Output: contract.Output{Type: state.Type{Schema: state.NewSchema(reflect.TypeOf(&shapeTestOutput{}))}}, + }, + View: view.NewRefView("rows"), + } + + err := components.mergeShapeViews(context.Background(), component) + require.NoError(t, err) + require.Len(t, components.Resource.Views, 1) + assert.Equal(t, "rows", components.Resource.Views[0].Name) +} + +func TestComponents_mergeShapeViews_Disabled(t *testing.T) { + resource := view.EmptyResource() + components := &Components{ + Resource: resource, + options: &Options{shapePipeline: false}, + } + + component := &Component{ + 
Path: contract.Path{URI: "/v1/api/report", Method: "GET"}, + Contract: contract.Contract{ + Output: contract.Output{Type: state.Type{Schema: state.NewSchema(reflect.TypeOf(&shapeTestOutput{}))}}, + }, + View: view.NewRefView("rows"), + } + + err := components.mergeShapeViews(context.Background(), component) + require.NoError(t, err) + assert.Len(t, components.Resource.Views, 0) +} diff --git a/repository/components_typectx_test.go b/repository/components_typectx_test.go new file mode 100644 index 000000000..4239fd201 --- /dev/null +++ b/repository/components_typectx_test.go @@ -0,0 +1,164 @@ +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" +) + +func TestUnmarshalComponentMap_PropagatesTopLevelTypeContext(t *testing.T) { + model := map[string]any{ + "TypeContext": map[string]any{ + "DefaultPackage": "mdp/performance", + "Imports": []any{ + map[string]any{ + "Alias": "perf", + "Package": "github.com/acme/mdp/performance", + }, + }, + }, + "Components": []any{ + map[string]any{ + "URI": "/v1/api/sample", + "Method": "GET", + "View": map[string]any{ + "Ref": "sample", + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "sample"}, + }, + }, + } + components, err := unmarshalComponentMap(model, true) + require.NoError(t, err) + require.Len(t, components.Components, 1) + require.NotNil(t, components.Components[0].TypeContext) + require.Equal(t, "mdp/performance", components.Components[0].TypeContext.DefaultPackage) + require.Len(t, components.Components[0].TypeContext.Imports, 1) + require.Equal(t, "perf", components.Components[0].TypeContext.Imports[0].Alias) +} + +func TestUnmarshalComponentMap_PerComponentTypeContextOverridesTopLevel(t *testing.T) { + model := map[string]any{ + "TypeContext": map[string]any{ + "DefaultPackage": "top/level", + }, + "Components": []any{ + map[string]any{ + "URI": "/v1/api/sample", + 
"Method": "GET", + "View": map[string]any{ + "Ref": "sample", + }, + "TypeContext": map[string]any{ + "DefaultPackage": "component/level", + "Imports": []any{ + map[string]any{ + "Alias": "foo", + "Package": "github.com/acme/foo", + }, + }, + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "sample"}, + }, + }, + } + components, err := unmarshalComponentMap(model, true) + require.NoError(t, err) + require.Len(t, components.Components, 1) + require.NotNil(t, components.Components[0].TypeContext) + require.Equal(t, "component/level", components.Components[0].TypeContext.DefaultPackage) + require.Len(t, components.Components[0].TypeContext.Imports, 1) + require.Equal(t, "foo", components.Components[0].TypeContext.Imports[0].Alias) +} + +func TestUnmarshalComponentMap_NoTypeContextRemainsNil(t *testing.T) { + model := map[string]any{ + "Components": []any{ + map[string]any{ + "URI": "/v1/api/sample", + "Method": "GET", + "View": map[string]any{ + "Ref": "sample", + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "sample"}, + }, + }, + } + components, err := unmarshalComponentMap(model, true) + require.NoError(t, err) + require.Len(t, components.Components, 1) + require.Nil(t, components.Components[0].TypeContext) +} + +func TestUnmarshalComponentMap_TopLevelTypeContext_DisabledByFlag(t *testing.T) { + model := map[string]any{ + "TypeContext": map[string]any{ + "DefaultPackage": "mdp/performance", + }, + "Components": []any{ + map[string]any{ + "URI": "/v1/api/sample", + "Method": "GET", + "View": map[string]any{ + "Ref": "sample", + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "sample"}, + }, + }, + } + components, err := unmarshalComponentMap(model, false) + require.NoError(t, err) + require.Len(t, components.Components, 1) + require.Nil(t, components.Components[0].TypeContext) +} + +func TestResolveComponentTypeContext_FromTemplateSource(t *testing.T) { 
+ component := &Component{ + View: &view.View{ + Template: view.NewTemplate(` +#set($_ = $package('mdp/performance')) +#set($_ = $import('perf', 'github.com/acme/mdp/performance')) +SELECT ID FROM REPORT r`), + }, + } + resolved := resolveComponentTypeContext(component) + require.NotNil(t, resolved) + require.Equal(t, "mdp/performance", resolved.DefaultPackage) + require.Len(t, resolved.Imports, 1) + require.Equal(t, "perf", resolved.Imports[0].Alias) +} + +func TestResolveComponentTypeContext_PrefersExisting(t *testing.T) { + component := &Component{ + TypeContext: &typectx.Context{ + DefaultPackage: " custom/pkg ", + Imports: []typectx.Import{ + {Alias: " a ", Package: " github.com/acme/a "}, + }, + }, + } + resolved := resolveComponentTypeContext(component) + require.NotNil(t, resolved) + require.Equal(t, "custom/pkg", resolved.DefaultPackage) + require.Len(t, resolved.Imports, 1) + require.Equal(t, "a", resolved.Imports[0].Alias) + require.Equal(t, "github.com/acme/a", resolved.Imports[0].Package) +} diff --git a/repository/contract/dispatcher.go b/repository/contract/dispatcher.go index 22afc3d26..6da3f9f72 100644 --- a/repository/contract/dispatcher.go +++ b/repository/contract/dispatcher.go @@ -2,6 +2,7 @@ package contract import ( "context" + "github.com/viant/xdatly/handler/logger" hstate "github.com/viant/xdatly/handler/state" "net/http" "net/url" @@ -16,6 +17,7 @@ type ( Header http.Header Form *hstate.Form Request *http.Request + Logger logger.Logger } //Option represents a dispatcher option Option func(o *Options) @@ -77,3 +79,10 @@ func WithRequest(request *http.Request) Option { o.Request = request } } + +// WithLogger adds path parameters +func WithLogger(loger logger.Logger) Option { + return func(o *Options) { + o.Logger = loger + } +} diff --git a/repository/contract/meta.go b/repository/contract/meta.go index 878128b5e..8b0cfa71a 100644 --- a/repository/contract/meta.go +++ b/repository/contract/meta.go @@ -7,8 +7,9 @@ import ( // MCP Model 
Configuration Protocol path integration type Meta struct { - Name string `json:",omitempty" yaml:"Name"` // name of the MCP - Description string `json:",omitempty" yaml:"Description"` // optional description for documentation purposes + Name string `json:",omitempty" yaml:"Name"` // name of the MCP + Description string `json:",omitempty" yaml:"Description"` // optional description for documentation purposes + DescriptionURI string `json:",omitempty" yaml:"DescriptionURI"` } type ModelContextProtocol struct { diff --git a/repository/contract/signature/service.go b/repository/contract/signature/service.go index d1f79dbc8..572d0c3c5 100644 --- a/repository/contract/signature/service.go +++ b/repository/contract/signature/service.go @@ -58,6 +58,9 @@ func (s *Service) init(ctx context.Context) error { func (s *Service) Signature(method, URI string) (*Signature, error) { URI = strings.ReplaceAll(URI, "[]", "..") matchable, err := s.matcher.MatchOne(method, URI) + if err != nil && !strings.HasPrefix(URI, "/") { + matchable, err = s.matcher.MatchOne(method, "/"+URI) + } if err != nil && s.APIPrefix != "" { //fallback to full URI matchable, err = s.matcher.MatchOne(method, s.buildURI(URI)) } @@ -106,17 +109,42 @@ func (s *Service) Signature(method, URI string) (*Signature, error) { } func (s *Service) buildURI(URI string) string { - APIPrefix := strings.Split(s.APIPrefix, "/") - URIs := strings.Split(URI, "/") - var suffix []string - for _, item := range URIs { - if item == ".." 
{ - APIPrefix = APIPrefix[:len(APIPrefix)-1] + URI = strings.TrimSpace(URI) + if URI == "" { + return strings.TrimRight(s.APIPrefix, "/") + } + if strings.HasPrefix(URI, "/") { + return URI + } + + prefixParts := splitPathParts(s.APIPrefix) + uriParts := splitPathParts(URI) + for _, part := range uriParts { + switch part { + case ".", "": + continue + case "..": + if len(prefixParts) > 0 { + prefixParts = prefixParts[:len(prefixParts)-1] + } + default: + prefixParts = append(prefixParts, part) + } + } + return "/" + strings.Join(prefixParts, "/") +} + +func splitPathParts(input string) []string { + raw := strings.Split(input, "/") + result := make([]string, 0, len(raw)) + for _, item := range raw { + item = strings.TrimSpace(item) + if item == "" || item == "." { continue } - suffix = append(suffix, item) + result = append(result, item) } - return strings.Join(append(APIPrefix, suffix...), "/") + return result } func (s *Service) loadSignatures(ctx context.Context, URL string, isRoot bool) error { diff --git a/repository/handler/handler.go b/repository/handler/handler.go index 3d1c173d7..f8434d71a 100644 --- a/repository/handler/handler.go +++ b/repository/handler/handler.go @@ -7,7 +7,10 @@ import ( "github.com/viant/datly/view" "github.com/viant/datly/view/state" "github.com/viant/xdatly/handler" + "github.com/viant/xreflect" + "github.com/viant/xunsafe" "reflect" + "strings" ) var Type = reflect.TypeOf((*handler.Handler)(nil)).Elem() @@ -36,7 +39,9 @@ func (h *Handler) Init(ctx context.Context, resource *view.Resource) (err error) h.resource = resource aType, err = h.resource.TypeRegistry().Lookup(h.Type) if err != nil { - return fmt.Errorf("couldn't parse Handler type due to %w", err) + if aType = lookupByPackagePathAlias(h.resource.TypeRegistry().Lookup, h.Type); aType == nil { + return fmt.Errorf("couldn't parse Handler type due to %w", err) + } } } if aType.Kind() != reflect.Ptr { @@ -125,3 +130,34 @@ func NewHandler(handler handler.Handler) *Handler { rType 
:= reflect.TypeOf(handler) return &Handler{Type: rType.Name(), _type: rType} } + +func lookupByPackagePathAlias(lookup xreflect.LookupType, typeName string) reflect.Type { + typeName = strings.TrimSpace(typeName) + index := strings.LastIndex(typeName, ".") + if index == -1 || index == len(typeName)-1 { + return nil + } + pkgPath := typeName[:index] + name := typeName[index+1:] + if !strings.Contains(pkgPath, "/") { + return nil + } + segments := strings.Split(pkgPath, "/") + var candidates []string + if len(segments) >= 2 { + candidates = append(candidates, strings.Join(segments[len(segments)-2:], "/")) + } + candidates = append(candidates, segments[len(segments)-1]) + for _, candidate := range candidates { + if candidate == "" { + continue + } + if rType, err := lookup(name, xreflect.WithPackage(candidate)); err == nil && rType != nil { + return rType + } + } + if rType := xunsafe.LookupType(pkgPath + "/" + name); rType != nil { + return rType + } + return nil +} diff --git a/repository/locator/async/locator.go b/repository/locator/async/locator.go index bc8141cde..7d047fe42 100644 --- a/repository/locator/async/locator.go +++ b/repository/locator/async/locator.go @@ -9,13 +9,14 @@ import ( "github.com/viant/xdatly/handler/async" "github.com/viant/xdatly/handler/exec" "github.com/viant/xdatly/handler/response" + "reflect" "strings" "time" ) type Locator struct{} -func (l *Locator) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (l *Locator) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { name = strings.ToLower(name) if name == keys.JobError { diff --git a/repository/locator/component/component.go b/repository/locator/component/component.go index 1db480628..ffb308c2e 100644 --- a/repository/locator/component/component.go +++ b/repository/locator/component/component.go @@ -2,28 +2,33 @@ package component import ( "context" + "errors" "fmt" + "net/http" + "net/url" + "reflect" + 
"github.com/viant/datly/repository/contract" "github.com/viant/datly/shared" "github.com/viant/datly/view/state" "github.com/viant/datly/view/state/kind" "github.com/viant/datly/view/state/kind/locator" + "github.com/viant/xdatly/handler/logger" "github.com/viant/xdatly/handler/response" hstate "github.com/viant/xdatly/handler/state" "github.com/viant/xunsafe" - "net/http" - "net/url" - "reflect" ) type componentLocator struct { - custom []interface{} - dispatch contract.Dispatcher - constants map[string]interface{} - path map[string]string - form *hstate.Form - query url.Values - header http.Header + custom []interface{} + dispatch contract.Dispatcher + constants map[string]interface{} + path map[string]string + form *hstate.Form + query url.Values + header http.Header + logger logger.Logger + getRequest func() (*http.Request, error) } @@ -31,7 +36,7 @@ func (l *componentLocator) Names() []string { return nil } -func (l *componentLocator) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (l *componentLocator) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { method, URI := shared.ExtractPath(name) request, err := l.getRequest() if err != nil { @@ -43,6 +48,7 @@ func (l *componentLocator) Value(ctx context.Context, name string) (interface{}, contract.WithPath(l.path), contract.WithQuery(l.query), contract.WithForm(form), + contract.WithLogger(l.logger), contract.WithHeader(l.header), ) err = updateErrWithResponseStatus(err, value) @@ -53,7 +59,7 @@ func updateErrWithResponseStatus(err error, response interface{}) error { var statusErr error responseStatus, ok := tryExtractResponseStatus(response) if ok && responseStatus.Status == "error" { - statusErr = fmt.Errorf(responseStatus.Message) + statusErr = errors.New(responseStatus.Message) } if statusErr != nil { @@ -102,6 +108,7 @@ func newComponentLocator(opts ...locator.Option) (kind.Locator, error) { dispatch: options.Dispatcher, constants: 
options.Constants, getRequest: options.GetRequest, + logger: options.Logger, form: options.Form, query: options.Query, header: options.Header, diff --git a/repository/locator/component/dispatcher/disptacher.go b/repository/locator/component/dispatcher/disptacher.go index 0ea2fbd3b..0c13a136f 100644 --- a/repository/locator/component/dispatcher/disptacher.go +++ b/repository/locator/component/dispatcher/disptacher.go @@ -47,6 +47,7 @@ func (d *Dispatcher) Dispatch(ctx context.Context, path *contract.Path, opts ... aSession := session.New(aComponent.View, session.WithLocatorOptions(options...), session.WithAuth(d.auth), session.WithRegistry(d.registry), + session.WithLogger(cOptions.Logger), session.WithComponent(aComponent), session.WithOperate(d.service.Operate)) ctx = aSession.Context(ctx, true) diff --git a/repository/locator/meta/locator.go b/repository/locator/meta/locator.go index 487a58be8..ba1b0c1af 100644 --- a/repository/locator/meta/locator.go +++ b/repository/locator/meta/locator.go @@ -7,13 +7,14 @@ import ( "github.com/viant/datly/view/state" "github.com/viant/datly/view/state/kind" "github.com/viant/datly/view/state/kind/locator" + "reflect" "strings" ) type metaLocator struct { } -func (l *metaLocator) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (l *metaLocator) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { value := ctx.Value(view.ContextKey) if value == nil { return nil, false, nil diff --git a/repository/locator/output/output.go b/repository/locator/output/output.go index c9368971a..ec6b3ed5f 100644 --- a/repository/locator/output/output.go +++ b/repository/locator/output/output.go @@ -3,6 +3,7 @@ package output import ( "context" "encoding/json" + "reflect" "strings" "github.com/viant/datly/repository/locator/output/keys" @@ -25,7 +26,7 @@ func (l *Locator) Names() []string { return nil } -func (l *Locator) Value(ctx context.Context, name string) (interface{}, bool, error) { 
+func (l *Locator) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { aName := strings.ToLower(name) switch aName { case keys.ViewData: diff --git a/repository/logging/logging.go b/repository/logging/logging.go index 2383e2f11..850bef918 100644 --- a/repository/logging/logging.go +++ b/repository/logging/logging.go @@ -1,48 +1,94 @@ package logging import ( + "encoding/json" + "errors" "fmt" - "github.com/goccy/go-json" - "github.com/viant/xdatly/handler/exec" + "reflect" + "runtime/debug" "strconv" - "time" + + "github.com/viant/xdatly/handler/exec" ) func Log(config *Config, execContext *exec.Context) { - execContext.ElapsedMs = int(time.Since(execContext.StartTime).Milliseconds()) + snap := execContext.SnapshotForLogging() includeSQL := config.ShallIncludeSQL() if !includeSQL { - execContext.Metrics = execContext.Metrics.HideMetrics() + snap.Metrics = snap.Metrics.HideMetrics() } if config.IsAuditEnabled() { - data, _ := json.Marshal(execContext) - fmt.Println("[AUDIT] " + string(data)) + data := safeMarshal("EXECCONTEXT", snap) + fmt.Println("[AUDIT]", string(data)) } if config.IsTracingEnabled() { - trace := execContext.Trace + trace := snap.Trace rootSpan := trace.Spans[0] - spans := execContext.Metrics.ToSpans(&rootSpan.SpanID) - if execContext.Auth != nil { - if execContext.Auth.UserID != 0 { - rootSpan.Attributes["jwt.uid"] = strconv.Itoa(execContext.Auth.UserID) + spans := snap.Metrics.ToSpans(&rootSpan.SpanID) + if snap.Auth != nil { + if snap.Auth.UserID != 0 { + rootSpan.Attributes["jwt.uid"] = strconv.Itoa(snap.Auth.UserID) } - if execContext.Auth.Username != "" { - rootSpan.Attributes["jwt.username"] = execContext.Auth.Username + if snap.Auth.Username != "" { + rootSpan.Attributes["jwt.username"] = snap.Auth.Username } - if execContext.Auth.Email != "" { - rootSpan.Attributes["jwt.email"] = execContext.Auth.Email + if snap.Auth.Email != "" { + rootSpan.Attributes["jwt.email"] = snap.Auth.Email } - if 
execContext.Auth.Scope != "" { - rootSpan.Attributes["jwt.scope"] = execContext.Auth.Scope + if snap.Auth.Scope != "" { + rootSpan.Attributes["jwt.scope"] = snap.Auth.Scope } } trace.Append(spans...) - if execContext.Error != "" { - trace.Spans[0].SetStatus(fmt.Errorf(execContext.Error)) + if snap.Error != "" { + trace.Spans[0].SetStatus(errors.New(snap.Error)) } else { - trace.Spans[0].SetStatusFromHTTPCode(execContext.StatusCode) + trace.Spans[0].SetStatusFromHTTPCode(snap.StatusCode) + } + traceData := safeMarshal("TRACE", trace) + fmt.Println("[TRACE]", string(traceData)) + } +} + +func safeMarshal(label string, v any) []byte { + defer func() { + if r := recover(); r != nil { + fmt.Printf("[LOG-MARSHAL-PANIC] label=%s type=%T panic=%v\nSTACK:\n%s\n", label, v, r, debug.Stack()) + if execCtx, ok := v.(*exec.Context); ok { + findBadField(execCtx) + } } - traceData, _ := json.Marshal(trace) - fmt.Println("[TRACE] " + string(traceData)) + }() + data, err := json.Marshal(v) + if err != nil { + fmt.Printf("[LOG-MARSHAL-ERROR] label=%s type=%T err=%v\n", label, v, err) + return nil + } + return data +} + +func findBadField(execCtx *exec.Context) { + val := reflect.ValueOf(execCtx).Elem() + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + fieldName := fieldType.Name + + // Skip unexported fields + if !field.CanInterface() { + continue + } + + func() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("[BAD-FIELD-PANIC] %s (%s): %v\n", fieldName, field.Type(), r) + } + }() + if _, err := json.Marshal(field.Interface()); err != nil { + fmt.Printf("[BAD-FIELD-ERROR] %s (%s): %v\n", fieldName, field.Type(), err) + } + }() } } diff --git a/repository/logging/logging_test.go b/repository/logging/logging_test.go new file mode 100644 index 000000000..a9beb4293 --- /dev/null +++ b/repository/logging/logging_test.go @@ -0,0 +1,194 @@ +package logging + +import ( + "bytes" + "encoding/json" + "io" + "os" + 
"strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/xdatly/handler/exec" +) + +// TestSafeMarshal_Success tests successful JSON marshaling +func TestSafeMarshal_Success(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + Value int `json:"value"` + } + + testData := TestStruct{ + Name: "test", + Value: 42, + } + + result := safeMarshal("TEST", testData) + assert.NotNil(t, result, "safeMarshal should return non-nil for valid data") + + var unmarshaled TestStruct + err := json.Unmarshal(result, &unmarshaled) + assert.NoError(t, err) + assert.Equal(t, testData, unmarshaled) +} + +// TestSafeMarshal_Error tests safeMarshal with a value that causes a marshaling error +func TestSafeMarshal_Error(t *testing.T) { + // Channel cannot be marshaled to JSON + ch := make(chan int) + result := safeMarshal("TEST", ch) + assert.Nil(t, result, "safeMarshal should return nil when marshaling fails") +} + +// TestSafeMarshal_Panic tests safeMarshal with a value that causes a panic +func TestSafeMarshal_Panic(t *testing.T) { + // Function cannot be marshaled and may cause panic + fn := func() {} + result := safeMarshal("TEST", fn) + assert.Nil(t, result, "safeMarshal should return nil when marshaling panics") +} + +// TestSafeMarshal_ExecContext tests safeMarshal with exec.Context +func TestSafeMarshal_ExecContext(t *testing.T) { + execCtx := exec.NewContext("GET", "/test", nil, "") + result := safeMarshal("EXECCONTEXT", execCtx) + + // Should either succeed (return non-nil) or fail gracefully (return nil) + // The important thing is it doesn't panic + if result != nil { + assert.NotEmpty(t, result) + } +} + +// TestSafeMarshal_NilValue tests safeMarshal with nil value +func TestSafeMarshal_NilValue(t *testing.T) { + result := safeMarshal("TEST", nil) + assert.NotNil(t, result) + assert.Equal(t, []byte("null"), result) +} + +// TestFindBadField_ValidExecContext tests findBadField with a valid exec.Context +func 
TestFindBadField_ValidExecContext(t *testing.T) { + // Capture stdout to check output + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + execCtx := exec.NewContext("GET", "/test", nil, "") + findBadField(execCtx) + + // Close write pipe and restore stdout + w.Close() + os.Stdout = oldStdout + + // Read captured output + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + // With a valid exec.Context, there should be no bad field errors + assert.NotContains(t, output, "[BAD-FIELD-ERROR]", "valid exec.Context should not have bad fields") + assert.NotContains(t, output, "[BAD-FIELD-PANIC]", "valid exec.Context should not panic on field marshaling") +} + +// TestFindBadField_CompletesWithoutPanic tests that findBadField completes without panicking +func TestFindBadField_CompletesWithoutPanic(t *testing.T) { + execCtx := exec.NewContext("GET", "/test", nil, "") + + // Should complete without panicking + assert.NotPanics(t, func() { + findBadField(execCtx) + }) +} + +// TestSafeMarshal_WithLabel tests that safeMarshal uses the label parameter in error messages +func TestSafeMarshal_WithLabel(t *testing.T) { + // Capture stdout to verify label is used in error messages + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // Use a value that will cause an error + ch := make(chan int) + result := safeMarshal("CUSTOM_LABEL", ch) + + // Close write pipe and restore stdout + w.Close() + os.Stdout = oldStdout + + // Read captured output + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + assert.Nil(t, result, "should return nil on error") + if strings.Contains(output, "[LOG-MARSHAL-ERROR]") { + assert.Contains(t, output, "CUSTOM_LABEL", "error message should include the label") + } +} + +// TestSafeMarshal_RecoversFromPanic tests that safeMarshal properly recovers from panics +func TestSafeMarshal_RecoversFromPanic(t *testing.T) { + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = 
w + + // Create a type that will panic during JSON marshaling + type PanicType struct { + Value func() // Functions cannot be marshaled + } + + panicValue := PanicType{ + Value: func() {}, + } + + // This should not cause the test to panic + result := safeMarshal("PANIC_TEST", panicValue) + + // Close write pipe and restore stdout + w.Close() + os.Stdout = oldStdout + + // Read captured output + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + // Function should recover and return nil + assert.Nil(t, result, "should return nil after recovering from panic") + // Should log the panic + if strings.Contains(output, "[LOG-MARSHAL-PANIC]") { + assert.Contains(t, output, "PANIC_TEST", "panic log should include label") + } +} + +// TestSafeMarshal_ExecContextPanicCallsFindBadField tests that safeMarshal calls findBadField when exec.Context panics +func TestSafeMarshal_ExecContextPanicCallsFindBadField(t *testing.T) { + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + execCtx := exec.NewContext("GET", "/test", nil, "") + + // Try to marshal - if it panics, findBadField should be called + result := safeMarshal("EXECCONTEXT", execCtx) + + // Close write pipe and restore stdout + w.Close() + os.Stdout = oldStdout + + // Read captured output + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + // If marshaling panicked, findBadField should have been called + if result == nil && strings.Contains(output, "[LOG-MARSHAL-PANIC]") { + // findBadField should have been called (though output may be empty if no bad fields found) + // The important thing is that the function didn't crash + assert.True(t, true, "findBadField should be called when exec.Context panics") + } +} diff --git a/repository/option.go b/repository/option.go index c660a7486..9c2b9b34b 100644 --- a/repository/option.go +++ b/repository/option.go @@ -43,6 +43,8 @@ type Options struct { constants map[string]string substitutes 
map[string]view.Substitutes authConfig aconfig.Config + shapePipeline bool + legacyTypeContext bool } func (o *Options) UseColumn() bool { @@ -242,6 +244,23 @@ func WithPath(aPath *path.Path) Option { } } +// WithShapePipeline enables the repository/shape scan->plan->load pipeline +// during components initialization. +// The default is false to preserve existing behavior. +func WithShapePipeline(enabled bool) Option { + return func(o *Options) { + o.shapePipeline = enabled + } +} + +// WithLegacyTypeContext enables TypeContext enrichment in legacy repository runtime. +// Disabled by default for rollback safety. +func WithLegacyTypeContext(enabled bool) Option { + return func(o *Options) { + o.legacyTypeContext = enabled + } +} + func WithJWTSigner(aSigner *signer.Config) Option { return func(o *Options) { o.authConfig.JwtSigner = aSigner diff --git a/repository/option_shape_test.go b/repository/option_shape_test.go new file mode 100644 index 000000000..11bf4ecba --- /dev/null +++ b/repository/option_shape_test.go @@ -0,0 +1,29 @@ +package repository + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithShapePipeline(t *testing.T) { + opts := NewOptions(nil) + assert.False(t, opts.shapePipeline) + + WithShapePipeline(true)(opts) + assert.True(t, opts.shapePipeline) + + WithShapePipeline(false)(opts) + assert.False(t, opts.shapePipeline) +} + +func TestWithLegacyTypeContext(t *testing.T) { + opts := NewOptions(nil) + assert.False(t, opts.legacyTypeContext) + + WithLegacyTypeContext(true)(opts) + assert.True(t, opts.legacyTypeContext) + + WithLegacyTypeContext(false)(opts) + assert.False(t, opts.legacyTypeContext) +} diff --git a/repository/path/service.go b/repository/path/service.go index 260fff91f..2ce530d9d 100644 --- a/repository/path/service.go +++ b/repository/path/service.go @@ -198,7 +198,8 @@ func (s *Service) buildPaths(ctx context.Context, candidate storage.Object, root } sourceURL := candidate.URL() if index := 
strings.Index(sourceURL, rootPath); index != -1 { - sourceURL = sourceURL[1+index+len(rootPath):] + sourceURL = sourceURL[index+len(rootPath):] + sourceURL = strings.TrimPrefix(sourceURL, "/") } anItem := &Item{ SourceURL: sourceURL, diff --git a/repository/provider.go b/repository/provider.go index ffe166891..b294fcc10 100644 --- a/repository/provider.go +++ b/repository/provider.go @@ -26,7 +26,9 @@ func (p *Provider) Component(ctx context.Context, opts ...Option) (*Component, e p.mux.Lock() defer p.mux.Unlock() if p.control.ChangeKind() == version.ChangeKindDeleted { - //TODO maybe return 404 error + if p.component != nil { + return p.component, nil + } return nil, nil } aComponent, err := p.newComponent(ctx, opts...) diff --git a/repository/resource/service.go b/repository/resource/service.go index d1104e891..19d098b54 100644 --- a/repository/resource/service.go +++ b/repository/resource/service.go @@ -3,6 +3,10 @@ package resource import ( "context" "fmt" + "strings" + "sync" + "time" + "github.com/viant/afs" "github.com/viant/afs/file" "github.com/viant/afs/storage" @@ -10,9 +14,6 @@ import ( "github.com/viant/cloudless/resource" "github.com/viant/datly/repository/version" "github.com/viant/datly/view" - "strings" - "sync" - "time" ) type ( diff --git a/repository/shape/README.md b/repository/shape/README.md new file mode 100644 index 000000000..30848cef5 --- /dev/null +++ b/repository/shape/README.md @@ -0,0 +1,94 @@ +# repository/shape + +`repository/shape` provides a dynamic, in-memory pipeline for building Datly runtime artifacts from either: + +- Go structs (`scan -> plan -> load`) +- DQL (`compile -> load`) + +without generating YAML route/resource files. + +## Packages + +- `shape/scan`: discovers view/state tags from struct fields (Embedder-aware). +- `shape/plan`: normalizes scan output into a deterministic shape plan. +- `shape/load`: materializes `view.Resource`, `view.View`, and a runtime-neutral component artifact. 
+- `shape/compile`: compiles DQL into a shape plan for dynamic loading. + +## Facade API + +Use `shape.Engine` or package helpers: + +- `shape.LoadViews(ctx, src, opts...)` +- `shape.LoadComponent(ctx, src, opts...)` +- `shape.LoadDQLViews(ctx, dql, opts...)` +- `shape.LoadDQLComponent(ctx, dql, opts...)` + +## Minimal Struct Flow + +```go +engine := shape.New( + shape.WithScanner(scan.New()), + shape.WithPlanner(plan.New()), + shape.WithLoader(load.New()), + shape.WithName("/v1/api/report"), +) + +views, err := engine.LoadViews(ctx, &MyOutput{}) +``` + +## Minimal DQL Flow + +```go +engine := shape.New( + shape.WithCompiler(compile.New()), + shape.WithLoader(load.New()), + shape.WithName("/v1/api/report"), +) + +component, err := engine.LoadDQLComponent(ctx, "SELECT id FROM ORDERS t") +``` + +## DQL Directives + +`shape` recognizes three directive forms in DQL: + +- `#set(...)`: contract declarations (legacy-compatible). +- `#define(...)`: contract declarations (alias of `#set(...)` for clearer intent). +- `#settings(...)` / `#setting(...)`: runtime/settings directives. + +Runtime/settings directives currently support: + +- `#settings($_ = $package('module/path'))` +- `#settings($_ = $import('alias', 'github.com/acme/pkg'))` +- `#settings($_ = $meta('docs/path.md'))` +- `#settings($_ = $cache(true, '5m'))` +- `#settings($_ = $mcp('tool.name', 'description', 'docs/mcp/tool.md'))` +- `#settings($_ = $connector('analytics'))` (default connector for views that do not already declare one) + +## Column Discovery Policy + +Shape compile now exposes column discovery policy for DQL->IR: + +- `auto` (default): require discovery for `SELECT *` and for views without concrete declared shape. +- `on`: always mark query views for discovery. +- `off`: disable discovery; compile fails when discovery is required. + +Use `shape.WithColumnDiscoveryModeDefault(...)` on engine defaults or `shape.WithColumnDiscoveryMode(...)` as compile option. 
+ +## Repository Integration + +`repository/components.go` can optionally merge views generated by the shape pipeline during init. + +Enable via: + +```go +repository.WithShapePipeline(true) +``` + +Default is disabled to preserve existing behavior. + +## Component Contract Parity + +Cross-component contract/signature parity target is documented in: + +- `compile/COMPONENT_CONTRACT_PARITY.md` diff --git a/repository/shape/column/detector.go b/repository/shape/column/detector.go new file mode 100644 index 000000000..79b6c8d1c --- /dev/null +++ b/repository/shape/column/detector.go @@ -0,0 +1,237 @@ +package column + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/view" + viewcolumn "github.com/viant/datly/view/column" + "github.com/viant/sqlparser" + "github.com/viant/sqlx/io" +) + +// Detector resolves columns for shape-generated views. +// +// Rules: +// - schema field order is canonical order +// - wildcard SQL always performs DB discovery +// - newly discovered columns are appended at the end +// - matched columns keep schema order but refresh metadata from DB +type Detector struct{} + +func New() *Detector { + return &Detector{} +} + +func (d *Detector) Resolve(ctx context.Context, resource *view.Resource, aView *view.View) (view.Columns, error) { + if aView == nil { + return nil, fmt.Errorf("shape column detector: nil view") + } + + base := columnsFromSchema(aView) + if !usesWildcard(aView) { + return base, nil + } + + discovered, err := d.detect(ctx, resource, aView) + if err != nil { + return nil, err + } + if len(base) == 0 { + return discovered, nil + } + return mergePreservingOrder(base, discovered), nil +} + +func (d *Detector) detect(ctx context.Context, resource *view.Resource, aView *view.View) (view.Columns, error) { + connector, err := lookupConnector(ctx, resource, aView) + if err != nil { + return nil, err + } + db, err := connector.DB() + if err != nil { + return nil, fmt.Errorf("shape column detector: failed to 
open db for view %s: %w", aView.Name, err) + } + query := sourceSQL(aView) + sqlColumns, err := viewcolumn.Discover(ctx, db, aView.Table, query) + if err != nil { + return nil, fmt.Errorf("shape column detector: discover failed for view %s: %w", aView.Name, err) + } + return view.NewColumns(sqlColumns, aView.ColumnsConfig), nil +} + +func lookupConnector(ctx context.Context, resource *view.Resource, aView *view.View) (*view.Connector, error) { + if resource == nil { + return nil, fmt.Errorf("shape column detector: missing resource for view %s", aView.Name) + } + if aView.Connector == nil { + return nil, fmt.Errorf("shape column detector: missing connector for wildcard view %s", aView.Name) + } + connectors := view.ConnectorSlice(resource.Connectors).Index() + connector := aView.Connector + if connector.Ref != "" { + lookup, err := connectors.Lookup(connector.Ref) + if err != nil { + return nil, fmt.Errorf("shape column detector: connector ref %s for view %s: %w", connector.Ref, aView.Name, err) + } + connector = lookup + } + if err := connector.Init(ctx, connectors); err != nil { + return nil, fmt.Errorf("shape column detector: connector init for view %s: %w", aView.Name, err) + } + return connector, nil +} + +func sourceSQL(aView *view.View) string { + if aView.Template != nil && strings.TrimSpace(aView.Template.Source) != "" { + return aView.Template.Source + } + return aView.Source() +} + +func usesWildcard(aView *view.View) bool { + if aView != nil && aView.Template == nil && strings.TrimSpace(aView.Table) != "" { + return true + } + query := sourceSQL(aView) + trimmed := strings.TrimSpace(strings.ToLower(query)) + if trimmed == "" { + return false + } + if !strings.Contains(trimmed, "*") { + return false + } + if !strings.HasPrefix(trimmed, "select") && !strings.HasPrefix(trimmed, "with") { + return true + } + parsed, err := sqlparser.ParseQuery(query) + if err != nil { + return true + } + return sqlparser.NewColumns(parsed.List).IsStarExpr() +} + +func 
columnsFromSchema(aView *view.View) view.Columns { + if aView == nil || aView.Schema == nil { + return nil + } + rType := aView.Schema.Type() + if rType == nil { + return nil + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + result := make(view.Columns, 0, rType.NumField()) + appendSchemaColumns(rType, "", &result) + return result +} + +func appendSchemaColumns(rType reflect.Type, ns string, columns *view.Columns) { + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if field.PkgPath != "" { // unexported + continue + } + if field.Anonymous { + inner := field.Type + for inner.Kind() == reflect.Ptr { + inner = inner.Elem() + } + if inner.Kind() == reflect.Struct { + appendSchemaColumns(inner, ns, columns) + } + continue + } + + tag := io.ParseTag(field.Tag) + if tag != nil && tag.Transient { + continue + } + + name := field.Name + if tag != nil && tag.Column != "" { + name = tag.Column + } + if tag != nil && tag.Ns != "" { + name = tag.Ns + name + } else if ns != "" { + name = ns + name + } + + columnType := field.Type + nullable := false + if columnType.Kind() == reflect.Ptr { + nullable = true + columnType = columnType.Elem() + } + *columns = append(*columns, view.NewColumn(name, columnType.String(), columnType, nullable, view.WithColumnTag(string(field.Tag)))) + } +} + +func mergePreservingOrder(base, discovered view.Columns) view.Columns { + if len(base) == 0 { + return discovered + } + if len(discovered) == 0 { + return base + } + seen := map[string]*view.Column{} + for _, item := range discovered { + if item == nil { + continue + } + seen[strings.ToLower(item.Name)] = item + } + result := make(view.Columns, 0, len(base)+len(discovered)) + for _, item := range base { + if item == nil { + continue + } + if fresh, ok := seen[strings.ToLower(item.Name)]; ok { + delete(seen, strings.ToLower(item.Name)) + // Keep schema name/order but refresh 
discovered metadata. + item.DataType = firstNonEmpty(fresh.DataType, item.DataType) + item.SetColumnType(firstType(fresh.ColumnType(), item.ColumnType())) + item.Nullable = fresh.Nullable + if item.DatabaseColumn == "" { + item.DatabaseColumn = fresh.DatabaseColumn + } + } + result = append(result, item) + } + for _, item := range discovered { + if item == nil { + continue + } + if _, ok := seen[strings.ToLower(item.Name)]; !ok { + continue + } + result = append(result, item) + delete(seen, strings.ToLower(item.Name)) + } + return result +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func firstType(values ...reflect.Type) reflect.Type { + for _, value := range values { + if value != nil { + return value + } + } + return nil +} diff --git a/repository/shape/column/detector_sqlite_test.go b/repository/shape/column/detector_sqlite_test.go new file mode 100644 index 000000000..f162b0fa1 --- /dev/null +++ b/repository/shape/column/detector_sqlite_test.go @@ -0,0 +1,57 @@ +package column + +import ( + "context" + "database/sql" + "path/filepath" + "reflect" + "strings" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type sqliteOrder struct { + VendorID int `sqlx:"name=VENDOR_ID"` + Name string `sqlx:"name=NAME"` +} + +func TestDetector_Resolve_SQLiteWildcard(t *testing.T) { + ctx := context.Background() + dsn := filepath.Join(t.TempDir(), "shape_detector.sqlite") + db, err := sql.Open("sqlite3", dsn) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(ctx, `CREATE TABLE VENDOR (VENDOR_ID INTEGER NOT NULL, NAME TEXT NOT NULL, STATUS TEXT)`) + require.NoError(t, err) + + resource := view.EmptyResource() + resource.Connectors = []*view.Connector{{Connection: 
view.Connection{DBConfig: view.DBConfig{Name: "db", Driver: "sqlite3", DSN: dsn}}}} + + aView := &view.View{ + Name: "vendor", + Table: "VENDOR", + Schema: state.NewSchema(reflect.TypeOf(sqliteOrder{}), state.WithMany()), + Template: view.NewTemplate("SELECT * FROM VENDOR"), + Connector: view.NewRefConnector("db"), + } + + resolved, err := New().Resolve(ctx, resource, aView) + require.NoError(t, err) + require.GreaterOrEqual(t, len(resolved), 3) + + // Schema order is preserved, discovered extra columns are appended. + assert.Equal(t, "VENDOR_ID", strings.ToUpper(resolved[0].Name)) + assert.Equal(t, "NAME", strings.ToUpper(resolved[1].Name)) + + names := make([]string, 0, len(resolved)) + for _, item := range resolved { + names = append(names, strings.ToUpper(item.Name)) + } + assert.Contains(t, names, "STATUS") +} diff --git a/repository/shape/column/detector_test.go b/repository/shape/column/detector_test.go new file mode 100644 index 000000000..cfc834b11 --- /dev/null +++ b/repository/shape/column/detector_test.go @@ -0,0 +1,59 @@ +package column + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type sampleOrder struct { + VendorID int `sqlx:"name=VENDOR_ID"` + Name string `sqlx:"name=NAME"` +} + +func TestUsesWildcard(t *testing.T) { + tests := []struct { + name string + view *view.View + want bool + }{ + {name: "select wildcard", view: &view.View{Template: view.NewTemplate("SELECT * FROM VENDOR")}, want: true}, + {name: "select explicit", view: &view.View{Template: view.NewTemplate("SELECT ID, NAME FROM VENDOR")}, want: false}, + {name: "table only", view: &view.View{Table: "VENDOR"}, want: true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.want, usesWildcard(tc.view)) + }) + } +} + +func TestColumnsFromSchema_Order(t *testing.T) { + aView := &view.View{Schema: state.NewSchema(reflect.TypeOf(sampleOrder{}), 
state.WithMany())} + cols := columnsFromSchema(aView) + require.Len(t, cols, 2) + require.Equal(t, "VENDOR_ID", cols[0].Name) + require.Equal(t, "NAME", cols[1].Name) +} + +func TestMergePreservingOrder_AppendsNewDetectedColumns(t *testing.T) { + base := view.Columns{ + view.NewColumn("VENDOR_ID", "int", reflect.TypeOf(int(0)), false), + view.NewColumn("NAME", "varchar", reflect.TypeOf(""), false), + } + detected := view.Columns{ + view.NewColumn("NAME", "text", reflect.TypeOf(""), true), + view.NewColumn("VENDOR_ID", "bigint", reflect.TypeOf(int64(0)), false), + view.NewColumn("STATUS", "int", reflect.TypeOf(int(0)), true), + } + merged := mergePreservingOrder(base, detected) + require.Len(t, merged, 3) + require.Equal(t, "VENDOR_ID", merged[0].Name) + require.Equal(t, "NAME", merged[1].Name) + require.Equal(t, "STATUS", merged[2].Name) + require.Equal(t, "bigint", merged[0].DataType) + require.Equal(t, "text", merged[1].DataType) +} diff --git a/repository/shape/compile/column_discovery_policy.go b/repository/shape/compile/column_discovery_policy.go new file mode 100644 index 000000000..8cb908167 --- /dev/null +++ b/repository/shape/compile/column_discovery_policy.go @@ -0,0 +1,113 @@ +package compile + +import ( + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" +) + +func applyColumnDiscoveryPolicy(result *plan.Result, compileOptions *shape.CompileOptions) []*dqlshape.Diagnostic { + if result == nil { + return nil + } + mode := normalizeColumnDiscoveryMode(shape.CompileColumnDiscoveryAuto) + if compileOptions != nil { + mode = normalizeColumnDiscoveryMode(compileOptions.ColumnDiscoveryMode) + } + + var diags []*dqlshape.Diagnostic + for _, item := range result.Views { + if item == nil || !isQueryLikeMode(item.Mode) { + continue + } + required 
:= mode == shape.CompileColumnDiscoveryOn + if requiresColumnDiscovery(item) { + required = true + } + item.ColumnsDiscovery = required + if !required { + continue + } + result.ColumnsDiscovery = true + if mode == shape.CompileColumnDiscoveryOff { + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeColDiscoveryReq, + Severity: dqlshape.SeverityError, + Message: "column discovery is required but disabled", + Hint: "enable column discovery or declare an explicit shape/type without wildcard projection", + Span: dqlshape.Span{ + Start: dqlshape.Position{Line: 1, Char: 1}, + End: dqlshape.Position{Line: 1, Char: 1}, + }, + }) + } + } + return diags +} + +func normalizeColumnDiscoveryMode(mode shape.CompileColumnDiscoveryMode) shape.CompileColumnDiscoveryMode { + switch mode { + case shape.CompileColumnDiscoveryAuto, shape.CompileColumnDiscoveryOn, shape.CompileColumnDiscoveryOff: + return mode + default: + return shape.CompileColumnDiscoveryAuto + } +} + +func isQueryLikeMode(mode string) bool { + mode = strings.TrimSpace(mode) + if mode == "" { + return true + } + return strings.EqualFold(mode, "SQLQuery") +} + +func requiresColumnDiscovery(item *plan.View) bool { + if item == nil { + return false + } + if usesWildcardSQL(item.SQL, item.Table) { + return true + } + return !hasConcreteShape(item) +} + +func hasConcreteShape(item *plan.View) bool { + if item == nil { + return false + } + rType := item.ElementType + if rType == nil { + rType = item.FieldType + } + if rType == nil { + return false + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + return rType.Kind() == reflect.Struct +} + +func usesWildcardSQL(sqlText, table string) bool { + if strings.TrimSpace(sqlText) == "" { + return strings.TrimSpace(table) != "" + } + lower := strings.ToLower(sqlText) + if !strings.Contains(lower, "*") { + return false + } + if !strings.HasPrefix(strings.TrimSpace(lower), "select") && 
!strings.HasPrefix(strings.TrimSpace(lower), "with") { + return true + } + parsed, err := sqlparser.ParseQuery(sqlText) + if err != nil { + return true + } + return sqlparser.NewColumns(parsed.List).IsStarExpr() +} diff --git a/repository/shape/compile/column_discovery_policy_test.go b/repository/shape/compile/column_discovery_policy_test.go new file mode 100644 index 000000000..72baa5c5c --- /dev/null +++ b/repository/shape/compile/column_discovery_policy_test.go @@ -0,0 +1,77 @@ +package compile + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + "github.com/viant/datly/repository/shape/plan" +) + +func TestApplyColumnDiscoveryPolicy_Auto_WildcardRequiresDiscovery(t *testing.T) { + result := &plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT * FROM ORDERS", + FieldType: reflect.TypeOf([]struct{ ID int }{}), + ElementType: reflect.TypeOf(struct{ ID int }{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryAuto}) + require.Empty(t, diags) + require.True(t, result.ColumnsDiscovery) + require.True(t, result.Views[0].ColumnsDiscovery) +} + +func TestApplyColumnDiscoveryPolicy_Auto_NoConcreteShapeRequiresDiscovery(t *testing.T) { + result := &plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT id FROM ORDERS", + FieldType: reflect.TypeOf([]map[string]any{}), + ElementType: reflect.TypeOf(map[string]any{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryAuto}) + require.Empty(t, diags) + require.True(t, result.ColumnsDiscovery) + require.True(t, result.Views[0].ColumnsDiscovery) +} + +func TestApplyColumnDiscoveryPolicy_Off_EmitsErrorWhenRequired(t *testing.T) { + 
result := &plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT * FROM ORDERS", + FieldType: reflect.TypeOf([]map[string]any{}), + ElementType: reflect.TypeOf(map[string]any{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryOff}) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeColDiscoveryReq, diags[0].Code) + assert.True(t, result.ColumnsDiscovery) + assert.True(t, result.Views[0].ColumnsDiscovery) +} + +func TestApplyColumnDiscoveryPolicy_On_AlwaysMarksQueryViews(t *testing.T) { + result := &plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT id FROM ORDERS", + FieldType: reflect.TypeOf([]struct{ ID int }{}), + ElementType: reflect.TypeOf(struct{ ID int }{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryOn}) + require.Empty(t, diags) + assert.True(t, result.ColumnsDiscovery) + assert.True(t, result.Views[0].ColumnsDiscovery) +} diff --git a/repository/shape/compile/compiler.go b/repository/shape/compile/compiler.go new file mode 100644 index 000000000..6fb6eadc0 --- /dev/null +++ b/repository/shape/compile/compiler.go @@ -0,0 +1,274 @@ +package compile + +import ( + "context" + "fmt" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/dml" + "github.com/viant/datly/repository/shape/compile/pipeline" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + "github.com/viant/datly/repository/shape/plan" +) + +// DQLCompiler compiles raw DQL into a shape plan that can be materialized by shape/load. 
+type DQLCompiler struct{} + +// New returns a DQL compiler implementation. +func New() *DQLCompiler { + return &DQLCompiler{} +} + +// CompileError represents one or more compilation diagnostics. +type CompileError struct { + Diagnostics []*dqlshape.Diagnostic +} + +func (e *CompileError) Error() string { + if e == nil || len(e.Diagnostics) == 0 { + return "shape compile failed" + } + first := e.Diagnostics[0] + if len(e.Diagnostics) == 1 { + return first.Error() + } + return fmt.Sprintf("%s (and %d more diagnostics)", first.Error(), len(e.Diagnostics)-1) +} + +// Compile implements shape.DQLCompiler. +func (c *DQLCompiler) Compile(ctx context.Context, source *shape.Source, opts ...shape.CompileOption) (*shape.PlanResult, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if source == nil { + return nil, shape.ErrNilSource + } + compileOptions := applyCompileOptions(opts) + pathLayout := newCompilePathLayout(compileOptions) + enforceStrict := compileOptions.Strict || normalizeCompileProfile(compileOptions.Profile) == shape.CompileProfileStrict + + prepared, allDiags, err := c.preprocessSource(source, compileOptions, pathLayout, enforceStrict) + if err != nil { + return nil, err + } + + root, compileDiags, err := c.compileRoot( + source.Name, prepared.Pre.SQL, prepared.Statements, prepared.Decision, + compileOptions.MixedMode, compileOptions.UnknownNonReadMode, + ) + if err != nil { + return nil, err + } + prepared.Pre.Mapper.Remap(compileDiags) + allDiags = append(allDiags, compileDiags...) 
+ if root == nil { + return nil, &CompileError{Diagnostics: allDiags} + } + + result := c.assembleResult(source, root, prepared, compileOptions, pathLayout, allDiags) + if enforceStrict && hasEscalationWarnings(result.Diagnostics) { + return nil, &CompileError{Diagnostics: filterEscalationDiagnostics(result.Diagnostics)} + } + if hasErrorDiagnostics(result.Diagnostics) { + return nil, &CompileError{Diagnostics: result.Diagnostics} + } + return &shape.PlanResult{Source: source, Plan: result}, nil +} + +// preprocessSource runs DQL preprocessing (type context, directives, handler +// detection) and returns a ready-to-compile prepared result with accumulated +// diagnostics. Returns an error only for fatal early failures. +func (c *DQLCompiler) preprocessSource( + source *shape.Source, + compileOptions *shape.CompileOptions, + pathLayout compilePathLayout, + enforceStrict bool, +) (*handlerPreprocessResult, []*dqlshape.Diagnostic, error) { + if strings.TrimSpace(source.DQL) == "" { + return nil, nil, shape.ErrNilDQL + } + pre := dqlpre.Prepare(source.DQL) + pre.TypeCtx = applyTypeContextDefaults(pre.TypeCtx, source, compileOptions, pathLayout) + pre.Diagnostics = append(pre.Diagnostics, typeContextDiagnostics(pre.TypeCtx, enforceStrict)...) + allDiags := append([]*dqlshape.Diagnostic{}, pre.Diagnostics...) 
+ if hasErrorDiagnostics(allDiags) { + return nil, nil, &CompileError{Diagnostics: allDiags} + } + + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + prepared := buildHandlerIfNeeded(source, pre, statements, decision, pathLayout) + if strings.TrimSpace(prepared.Pre.SQL) == "" { + allDiags = append(allDiags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeParseEmpty, + Severity: dqlshape.SeverityError, + Message: "no SQL statement found", + Hint: "add SELECT/INSERT/UPDATE/DELETE statement after DQL directives", + Span: dqlshape.Span{ + Start: dqlshape.Position{Line: 1, Char: 1}, + End: dqlshape.Position{Line: 1, Char: 1}, + }, + }) + return nil, nil, &CompileError{Diagnostics: allDiags} + } + return prepared, allDiags, nil +} + +// assembleResult builds the plan.Result from the compiled root view, attaches +// declared relations/views/states, applies enrichment, and computes the final +// column-discovery policy diagnostics. +func (c *DQLCompiler) assembleResult( + source *shape.Source, + root *plan.View, + prepared *handlerPreprocessResult, + compileOptions *shape.CompileOptions, + pathLayout compilePathLayout, + diags []*dqlshape.Diagnostic, +) *plan.Result { + result := newPlanResult(root) + result.Diagnostics = diags + result.TypeContext = prepared.Pre.TypeCtx + result.Directives = prepared.Pre.Directives + applyDefaultConnectorDirective(result) + hints := extractViewHints(source.DQL) + appendRelationViews(result, root, hints) + appendDeclaredViews(source.DQL, result) + appendDeclaredStates(source.DQL, result) + applyViewHints(result, hints) + applySourceParityEnrichmentWithLayout(result, source, pathLayout) + applyLinkedTypeSupport(result, source) + result.Diagnostics = append(result.Diagnostics, applyColumnDiscoveryPolicy(result, compileOptions)...) 
+ return result +} + +func applyDefaultConnectorDirective(result *plan.Result) { + if result == nil || result.Directives == nil { + return + } + connector := strings.TrimSpace(result.Directives.DefaultConnector) + if connector == "" { + return + } + for _, item := range result.Views { + if item == nil || strings.TrimSpace(item.Connector) != "" { + continue + } + item.Connector = connector + } +} + +func (c *DQLCompiler) compileRoot(sourceName, sqlText string, statements dqlstmt.Statements, decision pipeline.Decision, mode shape.CompileMixedMode, unknownMode shape.CompileUnknownNonReadMode) (*plan.View, []*dqlshape.Diagnostic, error) { + mode = normalizeMixedMode(mode) + unknownMode = normalizeUnknownNonReadMode(unknownMode) + if !decision.HasRead && !decision.HasExec && decision.HasUnknown { + diag := &dqlshape.Diagnostic{ + Code: dqldiag.CodeParseUnknownNonRead, + Severity: dqlshape.SeverityWarning, + Message: "no readable SELECT statement detected", + Hint: "use SELECT for read parsing or compile as DML/handler template", + Span: pipeline.StatementSpan(sqlText, statements[0]), + } + if unknownMode == shape.CompileUnknownNonReadError { + diag.Severity = dqlshape.SeverityError + return nil, []*dqlshape.Diagnostic{diag}, nil + } + view, execDiags := pipeline.BuildExec(sourceName, sqlText, statements) + return view, append([]*dqlshape.Diagnostic{diag}, execDiags...), nil + } + if decision.HasRead && decision.HasExec { + switch mode { + case shape.CompileMixedModeErrorOnMixed: + return nil, []*dqlshape.Diagnostic{ + { + Code: dqldiag.CodeDMLMixed, + Severity: dqlshape.SeverityError, + Message: "mixed read/exec script is not allowed by compile mixed mode", + Hint: "use WithMixedMode(shape.CompileMixedModeExecWins) or split handlers", + Span: pipeline.StatementSpan(sqlText, statements[0]), + }, + }, nil + case shape.CompileMixedModeReadWins: + readSQL := sqlText + for _, stmt := range statements { + if stmt != nil && stmt.Kind == dqlstmt.KindRead { + readSQL = 
sqlText[stmt.Start:stmt.End] + break + } + } + view, diags, err := pipeline.BuildRead(sourceName, readSQL) + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLMixed, + Severity: dqlshape.SeverityWarning, + Message: "mixed read/exec script detected; read compilation path selected", + Hint: "split SELECT and DML into separate handlers when possible", + Span: pipeline.StatementSpan(sqlText, statements[0]), + }) + return view, diags, err + } + } + if decision.HasExec { + view, diags := dml.Compile(sourceName, sqlText, statements) + if decision.HasRead { + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLMixed, + Severity: dqlshape.SeverityWarning, + Message: "mixed read/exec script detected; exec compilation path selected", + Hint: "split SELECT and DML into separate handlers when possible", + Span: pipeline.StatementSpan(sqlText, statements[0]), + }) + } + return view, diags, nil + } + return pipeline.BuildRead(sourceName, sqlText) +} + +func normalizeMixedMode(mode shape.CompileMixedMode) shape.CompileMixedMode { + switch mode { + case shape.CompileMixedModeExecWins, shape.CompileMixedModeReadWins, shape.CompileMixedModeErrorOnMixed: + return mode + default: + return shape.CompileMixedModeExecWins + } +} + +func normalizeUnknownNonReadMode(mode shape.CompileUnknownNonReadMode) shape.CompileUnknownNonReadMode { + switch mode { + case shape.CompileUnknownNonReadWarn, shape.CompileUnknownNonReadError: + return mode + default: + return shape.CompileUnknownNonReadWarn + } +} + +func normalizeCompileProfile(profile shape.CompileProfile) shape.CompileProfile { + switch profile { + case shape.CompileProfileCompat, shape.CompileProfileStrict: + return profile + default: + return shape.CompileProfileCompat + } +} + +func newPlanResult(root *plan.View) *plan.Result { + result := &plan.Result{ + Views: []*plan.View{root}, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + result.ViewsByName[root.Name] = root + 
return result +} + +func applyCompileOptions(opts []shape.CompileOption) *shape.CompileOptions { + ret := &shape.CompileOptions{} + for _, opt := range opts { + if opt != nil { + opt(ret) + } + } + return ret +} diff --git a/repository/shape/compile/compiler_test.go b/repository/shape/compile/compiler_test.go new file mode 100644 index 000000000..85b51108c --- /dev/null +++ b/repository/shape/compile/compiler_test.go @@ -0,0 +1,837 @@ +package compile + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func TestDQLCompiler_Compile(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT id FROM ORDERS t"}) + require.NoError(t, err) + require.NotNil(t, res) + + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.Len(t, planned.Views, 1) + view := planned.Views[0] + assert.Equal(t, "t", view.Name) + assert.Equal(t, "ORDERS", view.Table) + assert.Equal(t, "many", view.Cardinality) + require.NotNil(t, view.FieldType) + assert.Contains(t, view.FieldType.String(), "Id") +} + +func TestDQLCompiler_Compile_EmptyDQL(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "x"}) + require.Error(t, err) + assert.ErrorIs(t, err, shape.ErrNilDQL) +} + +func TestDQLCompiler_Compile_WithPreamble_NoPanic(t *testing.T) { + compiler := New() + dql := ` +/* metadata */ +#set($_ = $A(query/a).Optional()) +SELECT id +` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "sample_report", DQL: dql}) + require.NoError(t, err) + require.NotNil(t, res) + + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + 
require.Len(t, planned.Views, 1) + assert.Equal(t, "sample_report", planned.Views[0].Name) + assert.Equal(t, "sample_report", planned.Views[0].Table) +} + +func TestDQLCompiler_Compile_PropagatesTypeContext(t *testing.T) { + compiler := New() + dql := ` +#package('mdp/performance') +#import('perf', 'github.com/acme/mdp/performance') +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + require.NotNil(t, res) + + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotNil(t, planned.TypeContext) + assert.Equal(t, "mdp/performance", planned.TypeContext.DefaultPackage) + require.Len(t, planned.TypeContext.Imports, 1) + assert.Equal(t, "perf", planned.TypeContext.Imports[0].Alias) +} + +func TestDQLCompiler_Compile_PropagatesImportedTypeContextWithModuleNormalization(t *testing.T) { + compiler := New() + projectDir := t.TempDir() + err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module github.vianttech.com/viant/platform\n\ngo 1.23\n"), 0o644) + require.NoError(t, err) + source := &shape.Source{ + Name: "orders_report", + Path: filepath.Join(projectDir, "dql", "platform", "taxonomy", "get.dql"), + DQL: "#import('session','pkg/platform/system/session')\nSELECT id FROM ORDERS t", + } + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotNil(t, planned.TypeContext) + require.Len(t, planned.TypeContext.Imports, 1) + assert.Equal(t, "session", planned.TypeContext.Imports[0].Alias) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/system/session", planned.TypeContext.Imports[0].Package) +} + +func TestDQLCompiler_Compile_PropagatesSpecialDirectives(t *testing.T) { + compiler := New() + dql := ` +#settings($_ = $meta('docs/orders.md')) +#settings($_ = $connector('analytics')) +#settings($_ = $cache(true, '5m')) 
+#settings($_ = $mcp('orders.search', 'Search orders', 'docs/mcp/orders.md')) +#settings($_ = $route('/v1/api/orders', 'GET', 'POST', 'PATCH')) +#settings($_ = $marshal('application/json','pkg.OrderJSON')) +#settings($_ = $unmarshal('application/json','pkg.OrderIn')) +#settings($_ = $unmarshal('application/xml','pkg.OrderXMLIn')) +#settings($_ = $format('tabular_json')) +#settings($_ = $date_format('2006-01-02')) +#settings($_ = $case_format('lc')) +SELECT id FROM ORDERS o +` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotNil(t, planned.Directives) + assert.Equal(t, "docs/orders.md", planned.Directives.Meta) + assert.Equal(t, "analytics", planned.Directives.DefaultConnector) + require.NotNil(t, planned.Directives.Cache) + assert.True(t, planned.Directives.Cache.Enabled) + assert.Equal(t, "5m", planned.Directives.Cache.TTL) + require.NotNil(t, planned.Directives.MCP) + assert.Equal(t, "orders.search", planned.Directives.MCP.Name) + assert.Equal(t, "Search orders", planned.Directives.MCP.Description) + assert.Equal(t, "docs/mcp/orders.md", planned.Directives.MCP.DescriptionPath) + require.NotNil(t, planned.Directives.Route) + assert.Equal(t, "/v1/api/orders", planned.Directives.Route.URI) + assert.Equal(t, []string{"GET", "POST", "PATCH"}, planned.Directives.Route.Methods) + assert.Equal(t, "pkg.OrderJSON", planned.Directives.JSONMarshalType) + assert.Equal(t, "pkg.OrderIn", planned.Directives.JSONUnmarshalType) + assert.Equal(t, "pkg.OrderXMLIn", planned.Directives.XMLUnmarshalType) + assert.Equal(t, "tabular", planned.Directives.Format) + assert.Equal(t, "2006-01-02", planned.Directives.DateFormat) + assert.Equal(t, "lc", planned.Directives.CaseFormat) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "analytics", planned.Views[0].Connector) +} + +func TestDQLCompiler_Compile_ColumnDiscoveryAutoForWildcard(t 
*testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT * FROM ORDERS o"}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.True(t, planned.ColumnsDiscovery) + require.NotEmpty(t, planned.Views) + assert.True(t, planned.Views[0].ColumnsDiscovery) +} + +func TestDQLCompiler_Compile_ColumnDiscoveryOffFailsWhenRequired(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT * FROM ORDERS o"}, + shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryOff)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeColDiscoveryReq, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_TypeContextValidationWarnsInCompat(t *testing.T) { + compiler := New() + dql := ` +#package('github.com/acme/perf') +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithTypeContextPackageName("bad/name")) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeTypeCtxInvalid, planned.Diagnostics[0].Code) + assert.Equal(t, dqlshape.SeverityWarning, planned.Diagnostics[0].Severity) +} + +func TestDQLCompiler_Compile_TypeContextValidationFailsInStrict(t *testing.T) { + compiler := New() + dql := `SELECT id FROM ORDERS t` + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, + shape.WithCompileProfile(shape.CompileProfileStrict), + shape.WithTypeContextPackageName("bad/name")) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, 
dqldiag.CodeTypeCtxInvalid, compileErr.Diagnostics[0].Code) + assert.Equal(t, dqlshape.SeverityError, compileErr.Diagnostics[0].Severity) +} + +func TestDQLCompiler_Compile_SyntaxError_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT id FROM ORDERS WHERE ("}) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + d := compileErr.Diagnostics[0] + assert.Equal(t, dqldiag.CodeParseSyntax, d.Code) + assert.Equal(t, 1, d.Span.Start.Line) + assert.Equal(t, 29, d.Span.Start.Char) +} + +func TestDQLCompiler_Compile_SyntaxError_RemapsAfterSanitize(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Id AND (" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + var diagnostics []*dqlshape.Diagnostic + if err != nil { + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + diagnostics = compileErr.Diagnostics + } else { + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + diagnostics = planned.Diagnostics + } + var d *dqlshape.Diagnostic + for _, item := range diagnostics { + if item != nil && item.Code == dqldiag.CodeParseSyntax { + d = item + break + } + } + if d != nil { + assert.Equal(t, 1, d.Span.Start.Line) + assert.Greater(t, d.Span.Start.Char, 0) + assert.LessOrEqual(t, d.Span.Start.Char, len(dql)) + } +} + +func TestDQLCompiler_Compile_DirectiveOnly_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "#package('x')"}) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + d := compileErr.Diagnostics[0] + assert.Equal(t, dqldiag.CodeParseEmpty, d.Code) + assert.Equal(t, 1, 
d.Span.Start.Line) + assert.Equal(t, 1, d.Span.Start.Char) +} + +func TestDQLCompiler_Compile_InvalidDirective_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_report", + DQL: "SELECT id FROM ORDERS t\n#import('alias')\nSELECT id FROM ORDERS t", + }) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + d := compileErr.Diagnostics[0] + assert.Equal(t, dqldiag.CodeDirImport, d.Code) + assert.Equal(t, 2, d.Span.Start.Line) + assert.Equal(t, 1, d.Span.Start.Char) +} + +func TestDQLCompiler_Compile_ExtractsJoinLinks(t *testing.T) { + compiler := New() + dql := "SELECT o.id, i.sku FROM orders o JOIN order_items i ON o.id = i.order_id" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + root := planned.ViewsByName["o"] + require.NotNil(t, root) + require.Len(t, root.Relations, 1) + assert.Equal(t, "i", root.Relations[0].Ref) + require.Len(t, root.Relations[0].On, 1) + assert.Equal(t, "o.id=i.order_id", root.Relations[0].On[0].Expression) + assert.Equal(t, "id", root.Relations[0].On[0].ParentColumn) + assert.Equal(t, "order_id", root.Relations[0].On[0].RefColumn) + assert.Empty(t, planned.Diagnostics) +} + +func TestDQLCompiler_Compile_JoinDiagnostics(t *testing.T) { + compiler := New() + dql := "SELECT o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeRelUnsupported, planned.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_StrictRelationWarningsFail(t *testing.T) { + compiler := New() + dql := "SELECT 
o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileStrict(true)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeRelUnsupported, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_ProfileStrictRelationWarningsFail(t *testing.T) { + compiler := New() + dql := "SELECT o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileProfile(shape.CompileProfileStrict)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeRelUnsupported, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_StrictAmbiguousLinkFail(t *testing.T) { + compiler := New() + dql := "SELECT o.id FROM orders o JOIN order_items i ON x.id = y.order_id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileStrict(true)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeRelAmbiguous, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_SQLInjectionDiagnostic(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Unsafe.Id" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeSQLIRawSelector, planned.Diagnostics[0].Code) + assert.Equal(t, 1, planned.Diagnostics[0].Span.Start.Line) + assert.Greater(t, 
planned.Diagnostics[0].Span.Start.Char, 1) +} + +func TestDQLCompiler_Compile_SanitizesBindings(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Id" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Contains(t, planned.Views[0].SQL, "$criteria.AppendBinding($Unsafe.Id)") +} + +func TestDQLCompiler_Compile_ParameterDerivedView(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view) /* SELECT code FROM EXTRA e */) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.Len(t, planned.Views, 2) + extra := planned.ViewsByName["e"] + require.NotNil(t, extra) + assert.Equal(t, "EXTRA", extra.Table) + assert.Contains(t, extra.SQL, "SELECT code FROM EXTRA e") +} + +func TestDQLCompiler_Compile_ParameterDerivedView_Options(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view).WithURI('/v1/extra').WithConnector('analytics').Cardinality('one') /* SELECT code FROM EXTRA e */) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + extra := planned.ViewsByName["e"] + require.NotNil(t, extra) + assert.Equal(t, "/v1/extra", extra.SQLURI) + assert.Equal(t, "analytics", extra.Connector) + assert.Equal(t, "one", extra.Cardinality) +} + +func TestDQLCompiler_Compile_ParameterDerivedView_MissingSQLHint(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view)) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", 
DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeViewMissingSQL, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_ParameterDerivedView_InvalidCardinalityDiagnostic(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view).Cardinality('few') /* SELECT code FROM EXTRA e */) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeViewCardinality, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_StrictSQLInjectionWarningsFail(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Unsafe.Id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileStrict(true)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeSQLIRawSelector, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_DMLInsert(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_exec", + DQL: "INSERT INTO ORDERS(id) VALUES (1)", + }) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.Len(t, planned.Views, 1) + assert.Equal(t, "ORDERS", planned.Views[0].Table) + assert.Equal(t, "many", planned.Views[0].Cardinality) +} + +func TestDQLCompiler_Compile_DMLServiceMissingArg(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_exec", + DQL: "$sql.Insert($rec)", + }) + require.Error(t, 
err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + var target *dqlshape.Diagnostic + for _, item := range compileErr.Diagnostics { + if item != nil && item.Code == dqldiag.CodeDMLServiceArg { + target = item + break + } + } + require.NotNil(t, target) + assert.Equal(t, 1, target.Span.Start.Line) + assert.Equal(t, 1, target.Span.Start.Char) +} + +func TestDQLCompiler_Compile_DMLSyntaxError_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_exec", + DQL: "#package('x')\nINSERT INTO ORDERS(id VALUES (1)", + }) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + var target *dqlshape.Diagnostic + for _, item := range compileErr.Diagnostics { + if item != nil && item.Code == dqldiag.CodeDMLInsert { + target = item + break + } + } + require.NotNil(t, target) + assert.Equal(t, 2, target.Span.Start.Line) + assert.Equal(t, 1, target.Span.Start.Char) +} + +func TestDQLCompiler_Compile_MixedReadExec_Warning(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT id FROM ORDERS\nUPDATE ORDERS SET id = 2", + }) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_MixedMode_ExecWins(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT o.id FROM ORDERS o\nUPDATE ORDERS SET id = 2", + }, shape.WithMixedMode(shape.CompileMixedModeExecWins)) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "ORDERS", 
planned.Views[0].Table) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_MixedMode_ReadWins(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT o.id FROM ORDERS o\nUPDATE ORDERS SET id = 2", + }, shape.WithMixedMode(shape.CompileMixedModeReadWins)) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "o", planned.Views[0].Name) + assert.Equal(t, "ORDERS", planned.Views[0].Table) + assert.Contains(t, planned.Views[0].SQL, "SELECT o.id FROM ORDERS o") + assert.NotContains(t, planned.Views[0].SQL, "UPDATE ORDERS") + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_MixedMode_ErrorOnMixed(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT o.id FROM ORDERS o\nUPDATE ORDERS SET id = 2", + }, shape.WithMixedMode(shape.CompileMixedModeErrorOnMixed)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, compileErr.Diagnostics[0].Code) + assert.Equal(t, dqlshape.SeverityError, compileErr.Diagnostics[0].Severity) +} + +func TestDQLCompiler_Compile_UnknownNonRead_Warn(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_report", + DQL: "$Foo.Bar($x)", + }) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + var found *dqlshape.Diagnostic + for _, item := range planned.Diagnostics { + if item != nil && item.Code == 
dqldiag.CodeParseUnknownNonRead { + found = item + break + } + } + require.NotNil(t, found) + assert.Equal(t, dqlshape.SeverityWarning, found.Severity) + require.NotEmpty(t, planned.Views) +} + +func TestDQLCompiler_Compile_UnknownNonRead_ErrorMode(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_report", + DQL: "$Foo.Bar($x)", + }, shape.WithUnknownNonReadMode(shape.CompileUnknownNonReadError)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + var found *dqlshape.Diagnostic + for _, item := range compileErr.Diagnostics { + if item != nil && item.Code == dqldiag.CodeParseUnknownNonRead { + found = item + break + } + } + require.NotNil(t, found) + assert.Equal(t, dqlshape.SeverityError, found.Severity) +} + +func TestResolveGeneratedCompanionDQL(t *testing.T) { + tempDir := t.TempDir() + dqlPath := filepath.Join(tempDir, "platform", "sitelist", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(dqlPath), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(filepath.Dir(dqlPath), "gen"), 0o755)) + generatedPath := filepath.Join(filepath.Dir(dqlPath), "gen", "patch.sql") + require.NoError(t, os.WriteFile(generatedPath, []byte("SELECT id FROM SITE_LIST sl"), 0o644)) + source := &shape.Source{ + Path: dqlPath, + DQL: `/* {"Type":"sitelist/patch.Handler"} */`, + } + actual := resolveGeneratedCompanionDQL(source) + require.Contains(t, actual, "SELECT id FROM SITE_LIST") +} + +func TestDQLCompiler_Compile_UnknownNonRead_UsesGeneratedCompanion(t *testing.T) { + tempDir := t.TempDir() + dqlPath := filepath.Join(tempDir, "platform", "adorder", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder", "patch.dql"), []byte("SELECT o.id FROM ORDERS o JOIN ORDER_ITEM i ON 
i.ORDER_ID = o.ID"), 0o644)) + source := &shape.Source{ + Name: "patch", + Path: dqlPath, + DQL: `/* {"Type":"adorder/patch.Handler"} */`, + } + + compiler := New() + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotNil(t, planned.ViewsByName["o"]) + require.NotNil(t, planned.ViewsByName["i"]) + var hasUnknownNonRead bool + for _, diag := range planned.Diagnostics { + if diag != nil && diag.Code == dqldiag.CodeParseUnknownNonRead { + hasUnknownNonRead = true + break + } + } + assert.False(t, hasUnknownNonRead) +} + +func TestResolveLegacyRouteViews(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"Connector":"ci_ads"} */`), 0o644)) + + routeDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch") + require.NoError(t, os.MkdirAll(routeDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "patch.sql"), []byte(`SELECT 1`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "CurCampaign.sql"), []byte(`SELECT * FROM CI_CAMPAIGN`), 0o644)) + + views := resolveLegacyRouteViews(&shape.Source{Path: sourcePath, DQL: `/* {"Connector":"ci_ads"} */`}) + require.Len(t, views, 2) + assert.Equal(t, "patch", views[0].Name) + assert.Equal(t, "", views[0].Table) + assert.Equal(t, "patch/patch.sql", views[0].SQLURI) + assert.Equal(t, "CurCampaign", views[1].Name) + assert.Equal(t, "CI_CAMPAIGN", views[1].Table) + assert.Equal(t, "ci_ads", views[1].Connector) +} + +func TestResolveLegacyRouteViews_TypeStemSubfolder(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + 
require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"Type":"campaign/patch.Handler","Connector":"ci_ads"} */`), 0o644)) + + routeDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch", "post") + require.NoError(t, os.MkdirAll(routeDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "post.sql"), []byte(`SELECT 1`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "CurCampaign.sql"), []byte(`SELECT * FROM CI_CAMPAIGN`), 0o644)) + + views := resolveLegacyRouteViews(&shape.Source{Path: sourcePath, DQL: `/* {"Type":"campaign/patch.Handler","Connector":"ci_ads"} */`}) + require.Len(t, views, 2) + assert.Equal(t, "post", views[0].Name) + assert.Equal(t, "CurCampaign", views[1].Name) + assert.Equal(t, "post/CurCampaign.sql", views[1].SQLURI) +} + +func TestDQLCompiler_Compile_HandlerNop_NoSQLiEscalation(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "handler_nop", + DQL: "$Nop($Unsafe.Id)", + }, shape.WithCompileStrict(true)) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + for _, item := range planned.Diagnostics { + if item == nil { + continue + } + assert.NotEqual(t, dqldiag.CodeSQLIRawSelector, item.Code) + } +} + +func TestDQLCompiler_Compile_SubqueryJoin_BuildsRelatedViewsAndConnectorHints(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Jwt(header/Authorization).WithCodec(JwtClaim).WithStatusCode(401)) +SELECT session.*, +use_connector(session, system), +use_connector(attribute, system) +FROM (SELECT * FROM session WHERE user_id = $Jwt.UserID) session +JOIN (SELECT * FROM session/attributes) attribute ON attribute.user_id = session.user_id +` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "system/session", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + root := planned.ViewsByName["session"] + 
require.NotNil(t, root) + assert.Equal(t, "system", root.Connector) + related := planned.ViewsByName["attribute"] + require.NotNil(t, related) + assert.Equal(t, "session/attributes", related.Table) + assert.Equal(t, "system", related.Connector) +} + +func TestDQLCompiler_Compile_GeneratedHandler_NoBodyInput_DoesNotLoadLegacyContractStates(t *testing.T) { + tempDir := t.TempDir() + genPath := filepath.Join(tempDir, "dql", "system", "upload", "gen", "upload", "delete.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) + require.NoError(t, os.WriteFile(genPath, []byte(`/* {"Method":"DELETE","URI":"/v1/api/system/upload"} */`), 0o644)) + + legacySQLPath := filepath.Join(tempDir, "dql", "system", "upload", "delete.sql") + require.NoError(t, os.MkdirAll(filepath.Dir(legacySQLPath), 0o755)) + require.NoError(t, os.WriteFile(legacySQLPath, []byte(`/* {"Type":"upload/delete.Handler","Connector":"system"} */`), 0o644)) + + routesDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "system", "upload") + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "delete"), 0o755)) + routeYAML := `Resource: + Parameters: + - Name: Method + In: + Kind: http_request + Name: method + - Name: UploadId + In: + Kind: query + Name: uploadId + Views: + - Name: delete + Mode: SQLExec + Connector: + Ref: system + Template: + SourceURL: delete/delete.sql +` + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "delete.yaml"), []byte(routeYAML), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "delete", "delete.sql"), []byte(`$Nop($Unsafe.UploadId)`), 0o644)) + + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "delete", Path: genPath, DQL: `/* {"Method":"DELETE","URI":"/v1/api/system/upload"} */`}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + + require.NotEmpty(t, planned.Views) + assert.Equal(t, "delete", planned.Views[0].Name) + assert.Equal(t, "SQLExec", 
planned.Views[0].Mode) + assert.Equal(t, "system", planned.Views[0].Connector) + + assert.Empty(t, planned.States) +} + +func TestDQLCompiler_Compile_HandlerLegacyTypes_NotLoadedFromLegacyRouteYAML(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler"} */`), 0o644)) + + rootRouteDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch") + require.NoError(t, os.MkdirAll(filepath.Join(rootRouteDir, "post"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(rootRouteDir, "post.yaml"), []byte(`Resource: + Parameters: + - Name: Auth + In: + Kind: component + Name: GET:/v1/api/platform/acl/auth + Views: + - Name: post + Mode: SQLExec + Connector: + Ref: ci_ads + Template: + SourceURL: post/post.sql + Types: + - Name: Input + DataType: "*Input" + Package: campaign/patch + ModulePath: github.vianttech.com/viant/platform/pkg/platform/campaign/patch + - Name: Handler + DataType: "*Handler" + Package: campaign/patch + ModulePath: github.vianttech.com/viant/platform/pkg/platform/campaign/patch +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(rootRouteDir, "post", "post.sql"), []byte(`$Nop($Unsafe.Id)`), 0o644)) + + componentRouteDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "acl", "auth") + require.NoError(t, os.MkdirAll(componentRouteDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(componentRouteDir, "auth.yaml"), []byte(`Resource: + Types: + - Name: Input + DataType: "*Input" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth + - Name: Handler + DataType: "*Handler" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth +`), 
0o644)) + + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "post", + Path: sourcePath, + DQL: `/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler"} */`, + }) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + + assert.Empty(t, planned.Types) +} + +func TestDQLCompiler_Compile_CustomPathLayout_NoLegacyHandlerFallback(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "sqlsrc", "platform", "campaign", "post.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler","Connector":"ci_ads"} */`), 0o644)) + + routesDir := filepath.Join(tempDir, "config", "routes", "platform", "campaign", "patch") + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "post"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "post.yaml"), []byte(`Resource: + Views: + - Name: post + Mode: SQLExec + Connector: + Ref: ci_ads + Template: + SourceURL: post/post.sql +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "post", "post.sql"), []byte(`$Nop($Unsafe.Id)`), 0o644)) + + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "post", + Path: sourcePath, + DQL: `/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler","Connector":"ci_ads"} */`, + }, shape.WithDQLPathMarker("sqlsrc"), shape.WithRoutesRelativePath("config/routes")) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "post", planned.Views[0].Name) + assert.NotContains(t, planned.Views[0].SQL, "$Nop(") +} diff --git a/repository/shape/compile/component_types.go b/repository/shape/compile/component_types.go new file mode 100644 index 
000000000..c7c8f22db --- /dev/null +++ b/repository/shape/compile/component_types.go @@ -0,0 +1,496 @@ +package compile + +import ( + "os" + "path/filepath" + "sort" + "strings" + + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view/state" + "gopkg.in/yaml.v3" +) + +type componentVisitState int + +const ( + componentVisitIdle componentVisitState = iota + componentVisitActive + componentVisitDone +) + +func appendComponentTypes(source *shape.Source, result *plan.Result) []*dqlshape.Diagnostic { + return appendComponentTypesWithLayout(source, result, defaultCompilePathLayout()) +} + +func appendComponentTypesWithLayout(source *shape.Source, result *plan.Result, layout compilePathLayout) []*dqlshape.Diagnostic { + if source == nil || result == nil { + return nil + } + _, routesRoot, dqlRoot, ok := sourceRootsWithLayout(source.Path, layout) + if !ok { + return nil + } + sourceNamespace, _ := dqlToRouteNamespaceWithLayout(source.Path, layout) + collector := &componentCollector{ + routesRoot: routesRoot, + visited: map[string]componentVisitState{}, + outputByRoute: map[string]string{}, + typesByName: map[string]*plan.Type{}, + payloadCache: map[string]routePayloadLookup{}, + reportedDiag: map[string]bool{}, + } + if strings.TrimSpace(sourceNamespace) != "" { + collector.collect(sourceNamespace, relationSpan(source.DQL, 0), false) + } + + for _, stateItem := range result.States { + if stateItem == nil || state.Kind(strings.ToLower(stateItem.KindString())) != state.KindComponent { + continue + } + ref := stateItem.InName() + if ref == "" { + continue + } + namespace := resolveComponentNamespaceWithNamespace(ref, source.Path, dqlRoot, sourceNamespace) + if namespace == "" { + collector.diags = append(collector.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompRefInvalid, + 
Severity: dqlshape.SeverityWarning, + Message: "invalid component reference: " + ref, + Hint: "use ../component/ref or GET:/v1/api/... route reference", + Span: componentRefSpan(source.DQL, ref), + }) + continue + } + outputType, ok := collector.collect(namespace, componentRefSpan(source.DQL, ref), true) + if ok && strings.TrimSpace(outputType) != "" { + if stateItem.Schema == nil { + stateItem.Schema = &state.Schema{} + } + if strings.TrimSpace(stateItem.Schema.DataType) == "" { + stateItem.Schema.DataType = strings.TrimSpace(outputType) + } + } + } + + sort.Strings(collector.typeOrder) + existing := map[string]bool{} + reportedCollision := map[string]bool{} + for _, item := range result.Types { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + existing[strings.ToLower(strings.TrimSpace(item.Name))] = true + } + for _, name := range collector.typeOrder { + keyName := strings.ToLower(strings.TrimSpace(name)) + if existing[keyName] { + if !reportedCollision[keyName] { + collector.diags = append(collector.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompTypeCollision, + Severity: dqlshape.SeverityWarning, + Message: "component type skipped due to existing type name: " + strings.TrimSpace(name), + Hint: "rename colliding type or keep route type as canonical source", + Span: relationSpan(source.DQL, 0), + }) + reportedCollision[keyName] = true + } + continue + } + item := collector.typesByName[name] + result.Types = append(result.Types, item) + existing[keyName] = true + } + return collector.diags +} + +type componentCollector struct { + routesRoot string + visited map[string]componentVisitState + outputByRoute map[string]string + // typesByName provides O(1) dedup; typeOrder tracks insertion sequence + // so the final list can be sorted once rather than extracted from the map. 
+ typesByName map[string]*plan.Type + typeOrder []string + payloadCache map[string]routePayloadLookup + reportedDiag map[string]bool + diags []*dqlshape.Diagnostic +} + +type routePayloadLookup struct { + payload *routePayload + found bool + malformed bool + malformedAt string + detail string +} + +func (c *componentCollector) collect(namespace string, span dqlshape.Span, required bool) (string, bool) { + key := strings.ToLower(strings.TrimSpace(namespace)) + if key == "" { + return "", false + } + switch c.visited[key] { + case componentVisitDone: + return c.outputByRoute[key], true + case componentVisitActive: + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompCycle, + Severity: dqlshape.SeverityWarning, + Message: "component reference cycle detected at " + namespace, + Hint: "break cyclic component references", + Span: span, + }) + return "", false + } + c.visited[key] = componentVisitActive + + payload, ok := c.loadRoutePayload(namespace, span) + if !ok { + c.visited[key] = componentVisitDone + if required && !c.hasReported("missing:"+key) { + c.reportedDiag["missing:"+key] = true + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompRouteMissing, + Severity: dqlshape.SeverityWarning, + Message: "component route YAML not found: " + namespace, + Hint: "ensure matching route exists under repo/dev/Datly/routes", + Span: span, + }) + } + return "", false + } + + for _, item := range payload.Resource.Types { + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + keyName := strings.ToLower(name) + if _, exists := c.typesByName[keyName]; exists { + continue + } + c.typeOrder = append(c.typeOrder, keyName) + c.typesByName[keyName] = &plan.Type{ + Name: name, + Alias: strings.TrimSpace(item.Alias), + DataType: strings.TrimSpace(item.DataType), + Cardinality: strings.TrimSpace(item.Cardinality), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + } + } + + outputType 
:= routeOutputType(payload) + c.outputByRoute[key] = outputType + + for _, param := range payload.Resource.Parameters { + if !strings.EqualFold(strings.TrimSpace(param.In.Kind), string(state.KindComponent)) { + continue + } + nextNS := resolveComponentNamespaceFromRoute(strings.TrimSpace(param.In.Name), namespace) + if nextNS == "" { + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompRefInvalid, + Severity: dqlshape.SeverityWarning, + Message: "invalid nested component reference: " + strings.TrimSpace(param.In.Name), + Hint: "use ../component/ref or GET:/v1/api/... route reference", + Span: span, + }) + continue + } + c.collect(nextNS, span, true) + } + + c.visited[key] = componentVisitDone + return outputType, true +} + +func sourceRootsWithLayout(sourcePath string, layout compilePathLayout) (platformRoot, routesRoot, dqlRoot string, ok bool) { + path := filepath.Clean(strings.TrimSpace(sourcePath)) + if path == "" { + return "", "", "", false + } + normalized := filepath.ToSlash(path) + marker := layout.dqlMarker + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + idx := strings.Index(normalized, marker) + if idx == -1 { + return "", "", "", false + } + platformRoot = path[:idx] + dqlRoot = filepath.Join(platformRoot, filepath.FromSlash(strings.Trim(marker, "/"))) + routesRoot = joinRelativePath(platformRoot, layout.routesRelative) + return platformRoot, routesRoot, dqlRoot, true +} + +func dqlToRouteNamespace(sourcePath string) (string, bool) { + return dqlToRouteNamespaceWithLayout(sourcePath, defaultCompilePathLayout()) +} + +func dqlToRouteNamespaceWithLayout(sourcePath string, layout compilePathLayout) (string, bool) { + path := filepath.Clean(strings.TrimSpace(sourcePath)) + if path == "" { + return "", false + } + normalized := filepath.ToSlash(path) + marker := layout.dqlMarker + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + idx := strings.Index(normalized, marker) + if idx == -1 { + 
return "", false + } + relative := strings.TrimPrefix(normalized[idx+len(marker):], "/") + if relative == "" { + return "", false + } + return strings.Trim(strings.TrimSuffix(relative, filepath.Ext(relative)), "/"), true +} + +func resolveComponentNamespace(ref, sourcePath, dqlRoot string) string { + ref = strings.TrimSpace(ref) + ref = strings.TrimPrefix(ref, "GET:") + ref = strings.TrimPrefix(ref, "POST:") + ref = strings.TrimPrefix(ref, "PUT:") + ref = strings.TrimPrefix(ref, "PATCH:") + ref = strings.TrimPrefix(ref, "DELETE:") + ref = strings.TrimPrefix(ref, "OPTIONS:") + ref = strings.TrimSpace(ref) + if strings.HasPrefix(ref, "/v1/api/") { + return strings.Trim(strings.TrimPrefix(ref, "/v1/api/"), "/") + } + if strings.HasPrefix(ref, "v1/api/") { + return strings.Trim(strings.TrimPrefix(ref, "v1/api/"), "/") + } + if strings.HasPrefix(ref, "/") { + return strings.Trim(ref, "/") + } + if dqlRoot == "" || strings.TrimSpace(sourcePath) == "" { + return "" + } + base := filepath.Dir(filepath.Clean(sourcePath)) + target := filepath.Clean(filepath.Join(base, ref)) + rel, err := filepath.Rel(dqlRoot, target) + if err != nil { + return "" + } + rel = filepath.ToSlash(rel) + rel = strings.TrimSuffix(rel, filepath.Ext(rel)) + return strings.Trim(rel, "/") +} + +func resolveComponentNamespaceWithNamespace(ref, sourcePath, dqlRoot, sourceNamespace string) string { + if namespace := resolveComponentNamespace(ref, sourcePath, dqlRoot); namespace != "" { + return namespace + } + return resolveComponentNamespaceFromRoute(ref, sourceNamespace) +} + +func resolveComponentNamespaceFromRoute(ref, sourceNamespace string) string { + ref = strings.TrimSpace(ref) + if ref == "" { + return "" + } + if namespace := resolveComponentNamespace(ref, "", ""); namespace != "" { + return namespace + } + normalizedBase := strings.Trim(strings.TrimSpace(sourceNamespace), "/") + if normalizedBase == "" { + return "" + } + baseDir := pathDir(normalizedBase) + target := 
filepath.ToSlash(filepath.Clean(filepath.Join(baseDir, ref))) + target = strings.TrimSuffix(target, filepath.Ext(target)) + return strings.Trim(target, "/") +} + +func pathDir(path string) string { + if path == "" { + return "" + } + parts := strings.Split(strings.Trim(path, "/"), "/") + if len(parts) <= 1 { + return "" + } + return strings.Join(parts[:len(parts)-1], "/") +} + +type routePayload struct { + Resource struct { + Types []struct { + Name string `yaml:"Name"` + Alias string `yaml:"Alias"` + DataType string `yaml:"DataType"` + Cardinality string `yaml:"Cardinality"` + Package string `yaml:"Package"` + ModulePath string `yaml:"ModulePath"` + } `yaml:"Types"` + Parameters []struct { + Name string `yaml:"Name"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + Schema struct { + DataType string `yaml:"DataType"` + Name string `yaml:"Name"` + Package string `yaml:"Package"` + Cardinality string `yaml:"Cardinality"` + } `yaml:"Schema"` + } `yaml:"Parameters"` + } `yaml:"Resource"` + Routes []struct { + Handler struct { + OutputType string `yaml:"OutputType"` + } `yaml:"Handler"` + Output struct { + Cardinality string `yaml:"Cardinality"` + Type struct { + Name string `yaml:"Name"` + Package string `yaml:"Package"` + } `yaml:"Type"` + } `yaml:"Output"` + } `yaml:"Routes"` +} + +func readRoutePayload(routesRoot, namespace string) routePayloadLookup { + candidates := routeYAMLCandidates(routesRoot, namespace) + lookup := routePayloadLookup{} + for _, candidate := range candidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + payload := &routePayload{} + if err = yaml.Unmarshal(data, payload); err != nil { + if !lookup.malformed { + lookup.malformed = true + lookup.malformedAt = candidate + lookup.detail = strings.TrimSpace(err.Error()) + } + continue + } + lookup.payload = payload + lookup.found = true + lookup.malformed = false + lookup.malformedAt = "" + lookup.detail = "" + return lookup + } + 
return lookup +} + +func (c *componentCollector) loadRoutePayload(namespace string, span dqlshape.Span) (*routePayload, bool) { + key := strings.ToLower(strings.TrimSpace(namespace)) + if key == "" { + return nil, false + } + lookup, ok := c.payloadCache[key] + if !ok { + lookup = readRoutePayload(c.routesRoot, namespace) + c.payloadCache[key] = lookup + } + if lookup.malformed && !lookup.found && !c.hasReported("invalid:"+key) { + c.reportedDiag["invalid:"+key] = true + message := "component route YAML malformed: " + namespace + if strings.TrimSpace(lookup.malformedAt) != "" { + message += " (" + lookup.malformedAt + ")" + } + hint := "fix route YAML format" + if strings.TrimSpace(lookup.detail) != "" { + hint += ": " + lookup.detail + } + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompRouteInvalid, + Severity: dqlshape.SeverityWarning, + Message: message, + Hint: hint, + Span: span, + }) + } + return lookup.payload, lookup.found +} + +func (c *componentCollector) hasReported(key string) bool { + if c == nil || c.reportedDiag == nil { + return false + } + return c.reportedDiag[key] +} + +func routeOutputType(payload *routePayload) string { + if payload == nil { + return "" + } + for _, route := range payload.Routes { + if outputType := strings.TrimSpace(route.Handler.OutputType); outputType != "" { + leaf := outputType + if idx := strings.LastIndex(leaf, "."); idx >= 0 && idx+1 < len(leaf) { + leaf = leaf[idx+1:] + } + leaf = strings.Trim(strings.TrimSpace(leaf), "*") + if leaf != "" { + return "*" + leaf + } + } + if name := strings.TrimSpace(route.Output.Type.Name); name != "" { + name = strings.Trim(name, "*") + if name != "" { + return "*" + name + } + } + } + for _, param := range payload.Resource.Parameters { + if strings.EqualFold(strings.TrimSpace(param.In.Kind), string(state.KindOutput)) { + if dataType := strings.TrimSpace(param.Schema.DataType); dataType != "" { + return dataType + } + if name := 
strings.TrimSpace(param.Schema.Name); name != "" { + name = strings.Trim(name, "*") + if name != "" { + return "*" + name + } + } + } + } + for _, item := range payload.Resource.Types { + if strings.EqualFold(strings.TrimSpace(item.Name), string(state.KindOutput)) { + if dataType := strings.TrimSpace(item.DataType); dataType != "" { + return dataType + } + return "*Output" + } + } + return "" +} + +func componentRefSpan(raw, ref string) dqlshape.Span { + offset := 0 + ref = strings.TrimSpace(ref) + if ref != "" { + if idx := strings.Index(raw, ref); idx >= 0 { + offset = idx + } + } + return relationSpan(raw, offset) +} + +func routeYAMLCandidates(routesRoot, namespace string) []string { + namespace = strings.Trim(namespace, "/") + if namespace == "" { + return nil + } + leaf := filepath.Base(namespace) + return []string{ + filepath.Join(routesRoot, filepath.FromSlash(namespace)+".yaml"), + filepath.Join(routesRoot, filepath.FromSlash(namespace), leaf+".yaml"), + } +} diff --git a/repository/shape/compile/component_types_test.go b/repository/shape/compile/component_types_test.go new file mode 100644 index 000000000..2cf803b7f --- /dev/null +++ b/repository/shape/compile/component_types_test.go @@ -0,0 +1,205 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view/state" +) + +func TestResolveComponentNamespace(t *testing.T) { + dqlRoot := "/repo/dql" + source := "/repo/dql/platform/tvaffiliatestation/tvaffiliatestation.dql" + assert.Equal(t, "platform/acl/auth", resolveComponentNamespace("../acl/auth", source, dqlRoot)) + assert.Equal(t, "platform/acl/auth", resolveComponentNamespace("GET:/v1/api/platform/acl/auth", source, dqlRoot)) + assert.Equal(t, "platform/acl/auth", 
resolveComponentNamespace("v1/api/platform/acl/auth", source, dqlRoot)) +} + +func TestDQLToRouteNamespace(t *testing.T) { + ns, ok := dqlToRouteNamespace("/repo/dql/platform/tvaffiliatestation/tvaffiliatestation.dql") + require.True(t, ok) + assert.Equal(t, "platform/tvaffiliatestation/tvaffiliatestation", ns) +} + +func TestSourceRoots_CustomLayout(t *testing.T) { + layout := compilePathLayout{ + dqlMarker: "/sqlsrc/", + routesRelative: "config/routes", + } + platformRoot, routesRoot, dqlRoot, ok := sourceRootsWithLayout("/repo/sqlsrc/platform/agency/agency.dql", layout) + require.True(t, ok) + assert.Equal(t, "/repo", filepath.ToSlash(platformRoot)) + assert.Equal(t, "/repo/config/routes", filepath.ToSlash(routesRoot)) + assert.Equal(t, "/repo/sqlsrc", filepath.ToSlash(dqlRoot)) + + ns, ok := dqlToRouteNamespaceWithLayout("/repo/sqlsrc/platform/agency/agency.dql", layout) + require.True(t, ok) + assert.Equal(t, "platform/agency/agency", ns) +} + +func TestAppendComponentTypes(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "tvaffiliatestation") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + require.NoError(t, os.MkdirAll(routesDir, 0o755)) + sourcePath := filepath.Join(dqlDir, "tvaffiliatestation.dql") + require.NoError(t, os.WriteFile(sourcePath, []byte("SELECT 1"), 0o644)) + + authYAML := `Resource: + Types: + - Name: Input + DataType: "*Input" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth + Parameters: + - In: + Kind: component + Name: GET:/v1/api/platform/acl/user +Routes: + - Handler: + OutputType: acl/auth.Output +` + userYAML := `Resource: + Types: + - Name: UserView + DataType: "struct{Id int;}" + Package: acl + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl +` + require.NoError(t, 
os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte(authYAML), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "user.yaml"), []byte(userYAML), 0o644)) + + result := &plan.Result{ + States: []*plan.State{ + {Parameter: state.Parameter{Name: "Auth", In: &state.Location{Kind: state.KindComponent, Name: "../acl/auth"}}}, + }, + } + appendComponentTypes(&shape.Source{Path: sourcePath, DQL: "#set($Auth = $component<../acl/auth>())"}, result) + require.Len(t, result.Types, 2) + names := map[string]bool{} + for _, item := range result.Types { + names[item.Name] = true + } + assert.True(t, names["Input"]) + assert.True(t, names["UserView"]) + assert.Equal(t, "*Output", result.States[0].Schema.DataType) +} + +func TestAppendComponentTypes_MissingComponentRoute(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "sample") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + sourcePath := filepath.Join(dqlDir, "sample.dql") + dql := "#set($Auth = $component<../acl/missing>())\nSELECT 1" + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + result := &plan.Result{ + States: []*plan.State{{Parameter: state.Parameter{Name: "Auth", In: &state.Location{Kind: state.KindComponent, Name: "../acl/missing"}}}}, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: dql}, result) + require.NotEmpty(t, diags) + assert.Equal(t, "DQL-COMP-ROUTE-MISSING", diags[0].Code) + assert.GreaterOrEqual(t, diags[0].Span.Start.Line, 1) + assert.GreaterOrEqual(t, diags[0].Span.Start.Char, 1) +} + +func TestAppendComponentTypes_TypeCollisionEmitsDiagnostic(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "tvaffiliatestation") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + sourcePath := 
filepath.Join(dqlDir, "tvaffiliatestation.dql") + require.NoError(t, os.WriteFile(sourcePath, []byte("SELECT 1"), 0o644)) + + authYAML := `Resource: + Types: + - Name: Input + DataType: "*Input" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth +` + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte(authYAML), 0o644)) + + result := &plan.Result{ + States: []*plan.State{ + {Parameter: state.Parameter{Name: "Auth", In: &state.Location{Kind: state.KindComponent, Name: "../acl/auth"}}}, + }, + Types: []*plan.Type{ + { + Name: "Input", + DataType: "*Input", + Package: "campaign/patch", + ModulePath: "github.vianttech.com/viant/platform/pkg/platform/campaign/patch", + }, + }, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: "#set($Auth = $component<../acl/auth>())"}, result) + require.NotEmpty(t, diags) + var found bool + for _, item := range diags { + if item != nil && item.Code == dqldiag.CodeCompTypeCollision { + found = true + break + } + } + assert.True(t, found) + require.Len(t, result.Types, 1) + assert.Equal(t, "campaign/patch", result.Types[0].Package) +} + +func TestAppendComponentTypes_InvalidRouteYAMLEmitsDiagnostic(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "sample") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + + sourcePath := filepath.Join(dqlDir, "sample.dql") + dql := "#set($Auth = $component<../acl/auth>())\nSELECT 1" + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte("Resource:\n Types: ["), 0o644)) + + result := &plan.Result{ + States: []*plan.State{{Parameter: state.Parameter{Name: "Auth", In: &state.Location{Kind: state.KindComponent, 
Name: "../acl/auth"}}}}, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: dql}, result) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeCompRouteInvalid, diags[0].Code) +} + +func TestAppendComponentTypes_InvalidRouteYAMLDedupedForRepeatedStates(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "sample") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + + sourcePath := filepath.Join(dqlDir, "sample.dql") + dql := "#set($Auth1 = $component<../acl/auth>())\n#set($Auth2 = $component<../acl/auth>())\nSELECT 1" + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte("Resource:\n Types: ["), 0o644)) + + result := &plan.Result{ + States: []*plan.State{ + {Parameter: state.Parameter{Name: "Auth1", In: &state.Location{Kind: state.KindComponent, Name: "../acl/auth"}}}, + {Parameter: state.Parameter{Name: "Auth2", In: &state.Location{Kind: state.KindComponent, Name: "../acl/auth"}}}, + }, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: dql}, result) + require.NotEmpty(t, diags) + invalidCount := 0 + for _, item := range diags { + if item != nil && item.Code == dqldiag.CodeCompRouteInvalid { + invalidCount++ + } + } + assert.Equal(t, 1, invalidCount) +} diff --git a/repository/shape/compile/dml/compiler.go b/repository/shape/compile/dml/compiler.go new file mode 100644 index 000000000..8b6838fda --- /dev/null +++ b/repository/shape/compile/dml/compiler.go @@ -0,0 +1,13 @@ +package dml + +import ( + "github.com/viant/datly/repository/shape/compile/pipeline" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + 
"github.com/viant/datly/repository/shape/plan" +) + +// Compile builds an exec-oriented view and validates DML statements. +func Compile(sourceName, sqlText string, statements dqlstmt.Statements) (*plan.View, []*dqlshape.Diagnostic) { + return pipeline.BuildExec(sourceName, sqlText, statements) +} diff --git a/repository/shape/compile/dml/compiler_test.go b/repository/shape/compile/dml/compiler_test.go new file mode 100644 index 000000000..7c1d907f9 --- /dev/null +++ b/repository/shape/compile/dml/compiler_test.go @@ -0,0 +1,25 @@ +package dml + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func TestCompile_Insert(t *testing.T) { + sqlText := "INSERT INTO ORDERS(id) VALUES (1)" + view, diags := Compile("orders_exec", sqlText, dqlstmt.New(sqlText)) + require.NotNil(t, view) + assert.Equal(t, "ORDERS", view.Table) + assert.Empty(t, diags) +} + +func TestCompile_ServiceMissingArg(t *testing.T) { + sqlText := "$sql.Insert($rec)" + _, diags := Compile("orders_exec", sqlText, dqlstmt.New(sqlText)) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeDMLServiceArg, diags[0].Code) +} diff --git a/repository/shape/compile/doc.go b/repository/shape/compile/doc.go new file mode 100644 index 000000000..c5a996ba8 --- /dev/null +++ b/repository/shape/compile/doc.go @@ -0,0 +1,2 @@ +// Package compile provides DQL-to-shape compilation. +package compile diff --git a/repository/shape/compile/enrich.go b/repository/shape/compile/enrich.go new file mode 100644 index 000000000..3ea8a277b --- /dev/null +++ b/repository/shape/compile/enrich.go @@ -0,0 +1,316 @@ +package compile + +// enrich.go — per-view enrichment passes applied after DQL compilation. +// Table inference helpers live in enrich_table.go; low-level text scanning +// primitives live in enrich_text.go. 
+ +import ( + "encoding/json" + "path/filepath" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" +) + +type ruleSettings struct { + Connector string `json:"Connector"` + Name string `json:"Name"` + Type string `json:"Type"` + Method string `json:"Method"` + URI string `json:"URI"` +} + +type parityEnrichmentContext struct { + source *shape.Source + settings *ruleSettings + baseDir string + module string + sourceName string + joinEmbedRefs map[string]string + joinSubqueryBodies map[string]string +} + +func applySourceParityEnrichment(result *plan.Result, source *shape.Source) { + applySourceParityEnrichmentWithLayout(result, source, defaultCompilePathLayout()) +} + +func applySourceParityEnrichmentWithLayout(result *plan.Result, source *shape.Source, layout compilePathLayout) { + if result == nil || len(result.Views) == 0 { + return + } + ctx := buildParityEnrichmentContext(result, source, layout) + for idx, item := range result.Views { + if item == nil { + continue + } + applyViewDefaults(item, idx == 0, ctx) + applyTableInference(item, ctx) + applyConnectorInference(item, ctx) + applySummaryInference(item, ctx) + } + if source != nil && strings.TrimSpace(source.Path) != "" { + normalizeRootViewName(result, ctx.sourceName) + } +} + +func buildParityEnrichmentContext(result *plan.Result, source *shape.Source, layout compilePathLayout) *parityEnrichmentContext { + ctx := &parityEnrichmentContext{ + source: source, + settings: extractRuleSettings(source, result.Directives), + baseDir: sourceSQLBaseDir(source), + module: sourceModuleWithLayout(source, layout), + sourceName: pipeline.SanitizeName(source.Name), + joinEmbedRefs: map[string]string{}, + joinSubqueryBodies: map[string]string{}, + } + if len(result.Views) == 0 || result.Views[0] == nil { + return ctx + } + sqlForJoinExtract := 
result.Views[0].SQL + if source != nil && strings.TrimSpace(source.DQL) != "" { + sqlForJoinExtract = source.DQL + } + ctx.joinEmbedRefs = extractJoinEmbedRefs(sqlForJoinExtract) + ctx.joinSubqueryBodies = extractJoinSubqueryBodies(sqlForJoinExtract) + return ctx +} + +func applyViewDefaults(item *plan.View, root bool, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil { + return + } + if item.SQLURI == "" && ctx.baseDir != "" { + item.SQLURI = ctx.baseDir + "/" + item.Name + ".sql" + } + if item.Module == "" { + item.Module = ctx.module + } + if item.SelectorNamespace == "" { + item.SelectorNamespace = defaultSelectorNamespace(item.Name) + } + if item.SchemaType == "" { + item.SchemaType = defaultSchemaType(item.Name, ctx.settings, root) + } +} + +func applyTableInference(item *plan.View, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil { + return + } + if shouldInferTable(item) { + candidateSQL := item.SQL + if strings.TrimSpace(candidateSQL) == "" { + candidateSQL = item.Table + } + if table := inferTableFromSQL(candidateSQL, ctx.source); table != "" { + item.Table = table + } + } + if strings.HasPrefix(strings.TrimSpace(item.Table), "(") || normalizedTemplatePlaceholderTable(strings.TrimSpace(item.Table)) { + if ref, ok := ctx.joinEmbedRefs[item.Name]; ok { + if table := inferTableFromEmbedRef(ctx.source, ref); table != "" { + item.Table = table + } + } + if body, ok := ctx.joinSubqueryBodies[item.Name]; ok { + if table := inferTableFromSQL(body, ctx.source); table != "" { + item.Table = table + } + } + if table := inferTableFromSiblingSQL(item.Name, ctx.source); table != "" { + item.Table = table + } + } +} + +func applyConnectorInference(item *plan.View, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil || item.Connector != "" { + return + } + if ctx.settings != nil && ctx.settings.Connector != "" { + item.Connector = ctx.settings.Connector + } + if item.Connector == "" && ctx.source != nil && 
strings.TrimSpace(ctx.source.Connector) != "" { + item.Connector = strings.TrimSpace(ctx.source.Connector) + } + if item.Connector == "" { + item.Connector = inferConnector(item, ctx.source) + } +} + +func applySummaryInference(item *plan.View, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil || item.Summary != "" { + return + } + item.Summary = extractSummarySQL(item.SQL) + if item.Summary == "" && ctx.source != nil { + item.Summary = extractSummarySQL(ctx.source.DQL) + } +} + +func extractSummarySQL(sqlText string) string { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" || !strings.Contains(sqlText, "$View.") { + return "" + } + body, ok := findSummaryJoinBody(sqlText) + if !ok { + return "" + } + return strings.TrimSpace(body) +} + +func extractRuleSettings(source *shape.Source, directives *dqlshape.Directives) *ruleSettings { + if source == nil || strings.TrimSpace(source.DQL) == "" { + return &ruleSettings{} + } + ret := &ruleSettings{} + if rawJSON, ok := extractLeadingRuleHeaderJSON(source.DQL); ok { + _ = json.Unmarshal([]byte(rawJSON), ret) + } + if directives != nil && directives.Route != nil { + if uri := strings.TrimSpace(directives.Route.URI); uri != "" { + ret.URI = uri + } + if len(directives.Route.Methods) > 0 { + ret.Method = strings.Join(directives.Route.Methods, ",") + } + } + return ret +} + +func sourceSQLBaseDir(source *shape.Source) string { + if source == nil { + return "" + } + path := strings.TrimSpace(source.Path) + if path == "" { + return "" + } + base := strings.TrimSpace(filepath.Base(path)) + if base == "" { + return "" + } + stem := strings.TrimSpace(strings.TrimSuffix(base, filepath.Ext(base))) + if stem == "" || stem == "." 
// defaultSelectorNamespace derives a short selector namespace from a view
// name: the first two ASCII letters of the name, lowercased. It returns fewer
// characters when the name contains fewer letters, and "" when it contains
// none (digits, punctuation and whitespace are skipped).
//
// The previous implementation lowercased each letter via
// strings.ToLower(string(ch)), allocating two strings per character, and
// scanned the entire name even though only the first two letters are used.
func defaultSelectorNamespace(name string) string {
	ns := make([]byte, 0, 2)
	for i := 0; i < len(name) && len(ns) < 2; i++ {
		switch ch := name[i]; {
		case ch >= 'a' && ch <= 'z':
			ns = append(ns, ch)
		case ch >= 'A' && ch <= 'Z':
			// ASCII lowercase without any string allocation.
			ns = append(ns, ch+('a'-'A'))
		}
	}
	return string(ns)
}
+ }) + if len(parts) == 0 { + return "" + } + var b strings.Builder + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + b.WriteString(strings.ToUpper(part[:1])) + if len(part) > 1 { + b.WriteString(part[1:]) + } + } + return b.String() +} + +// extractJoinEmbedRefs builds a map of view-alias → embed-path for every +// JOIN(${embed:path}) alias clause found in sqlText. +func extractJoinEmbedRefs(sqlText string) map[string]string { + result := map[string]string{} + if strings.TrimSpace(sqlText) == "" { + return result + } + for _, item := range scanJoinSubqueries(sqlText) { + ref, ok := parseJoinEmbedRef(item.body) + if !ok || ref == "" || item.alias == "" { + continue + } + result[item.alias] = ref + } + return result +} + +// extractJoinSubqueryBodies builds a map of view-alias → subquery-body for +// every JOIN(body) alias clause found in sqlText. +func extractJoinSubqueryBodies(sqlText string) map[string]string { + result := map[string]string{} + if strings.TrimSpace(sqlText) == "" { + return result + } + for _, item := range scanJoinSubqueries(sqlText) { + body := strings.TrimSpace(item.body) + if body == "" || item.alias == "" { + continue + } + result[item.alias] = body + } + return result +} diff --git a/repository/shape/compile/enrich_settings_test.go b/repository/shape/compile/enrich_settings_test.go new file mode 100644 index 000000000..87ecd1f66 --- /dev/null +++ b/repository/shape/compile/enrich_settings_test.go @@ -0,0 +1,26 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/datly/repository/shape" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func TestExtractRuleSettings_RouteDirectiveOverridesHeader(t *testing.T) { + source := &shape.Source{ + DQL: "/* {\"URI\":\"/v1/api/legacy\",\"Method\":\"GET\"} */\n" + + "#settings($_ = $route('/v1/api/orders', 'POST', 'PATCH'))\n" + + "SELECT 1", + } + + settings := 
// normalizedTemplatePlaceholderTable reports whether table looks like a
// template placeholder that was normalized into numeric segments, e.g.
// "1.1.SITE_LIST_MATCH": at least three dot-separated segments where every
// segment except the last is non-empty and purely numeric.
func normalizedTemplatePlaceholderTable(table string) bool {
	if table == "" {
		return false
	}
	segments := strings.Split(table, ".")
	if len(segments) < 3 {
		return false
	}
	// All leading segments (everything but the table name itself) must be
	// non-empty runs of decimal digits.
	for _, segment := range segments[:len(segments)-1] {
		segment = strings.TrimSpace(segment)
		if segment == "" {
			return false
		}
		for _, r := range segment {
			if r < '0' || r > '9' {
				return false
			}
		}
	}
	return true
}
pipeline.InferTableFromSQL(sqlText); table != "" { + if !strings.EqualFold(table, "DQLView") { + return table + } + } + if table := inferFromEmbeddedSQL(sqlText, source); table != "" { + return table + } + return "" +} + +func inferFromEmbeddedSQL(sqlText string, source *shape.Source) string { + ref, ok := findFirstEmbedRef(sqlText) + if !ok { + return "" + } + ref = strings.Trim(ref, `"'`) + if ref == "" { + return "" + } + resolved := resolveEmbedPath(source, ref) + if resolved == "" { + return "" + } + embedded, err := os.ReadFile(resolved) + if err != nil { + return "" + } + queryNode, _, err := pipeline.ParseSelectWithDiagnostic(string(embedded)) + if err != nil || queryNode == nil { + if table := pipeline.InferTableFromSQL(string(embedded)); table != "" && !strings.EqualFold(table, "DQLView") { + return strings.Trim(table, "`\"") + } + return "" + } + _, table, err := pipeline.InferRoot(queryNode, "") + if err != nil || strings.TrimSpace(table) == "" { + return "" + } + if strings.EqualFold(strings.TrimSpace(table), "DQLView") { + return "" + } + return strings.Trim(table, "`\"") +} + +func resolveEmbedPath(source *shape.Source, ref string) string { + if filepath.IsAbs(ref) { + return ref + } + if source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + base := source.Path + if fi, err := os.Stat(base); err == nil && fi.IsDir() { + return filepath.Clean(filepath.Join(base, ref)) + } + return filepath.Clean(filepath.Join(filepath.Dir(base), ref)) +} + +func inferTableFromSiblingSQL(viewName string, source *shape.Source) string { + viewName = strings.TrimSpace(viewName) + if viewName == "" || source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + sibling := filepath.Join(filepath.Dir(source.Path), viewName+".sql") + data, err := os.ReadFile(sibling) + if err != nil { + sibling = filepath.Join(filepath.Dir(source.Path), strings.ToLower(viewName)+".sql") + data, err = os.ReadFile(sibling) + } + if err != nil { + return "" + } + 
// topLevelFromExpr scans sqlText for the first FROM keyword at nesting depth
// zero (outside single quotes, double quotes, backticks and parentheses) and
// returns the expression that immediately follows it: either a bare table
// reference, or a parenthesised subquery including its trailing alias when
// present. It returns "" when no top-level FROM clause exists.
//
// Fix over the previous version: a candidate "from" must be delimited by
// non-word characters on BOTH sides. The old code checked only the preceding
// character, so "SELECT fromdate FROM t" matched the "from" inside "fromdate"
// and returned "date". The old "i+6 > len" guard is also tightened to the
// four bytes actually sliced.
func topLevelFromExpr(sqlText string) string {
	lower := strings.ToLower(sqlText)
	depth := 0
	inSingle, inDouble, inBacktick := false, false, false
	for i := 0; i < len(sqlText); i++ {
		// Track quoting and parenthesis nesting first; the keyword search
		// below only applies at depth zero outside any quotes.
		switch sqlText[i] {
		case '\'':
			if !inDouble && !inBacktick {
				inSingle = !inSingle
			}
		case '"':
			if !inSingle && !inBacktick {
				inDouble = !inDouble
			}
		case '`':
			if !inSingle && !inDouble {
				inBacktick = !inBacktick
			}
		case '(':
			if !inSingle && !inDouble && !inBacktick {
				depth++
			}
		case ')':
			if !inSingle && !inDouble && !inBacktick && depth > 0 {
				depth--
			}
		}
		if inSingle || inDouble || inBacktick || depth != 0 {
			continue
		}
		if i+4 > len(sqlText) {
			break
		}
		if lower[i:i+4] != "from" {
			continue
		}
		if i > 0 && isSQLWordChar(lower[i-1]) {
			continue // suffix of a longer identifier, e.g. "x_from"
		}
		if i+4 < len(lower) && isSQLWordChar(lower[i+4]) {
			continue // prefix of a longer identifier, e.g. "fromdate"
		}
		return fromOperand(sqlText, i+4)
	}
	return ""
}

// isSQLWordChar reports whether ch (already lowercased) can be part of an SQL
// identifier word; used to enforce keyword boundaries.
func isSQLWordChar(ch byte) bool {
	return ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')
}

// fromOperand extracts the FROM operand starting at pos (just past the FROM
// keyword): a parenthesised subquery plus an optional trailing alias, or a
// bare table reference, which may contain template placeholders (${db}.T) or
// path-like names (session/attributes).
func fromOperand(sqlText string, pos int) string {
	for pos < len(sqlText) && isSQLSpace(sqlText[pos]) {
		pos++
	}
	if pos >= len(sqlText) {
		return ""
	}
	if sqlText[pos] == '(' {
		start := pos
		// Consume the balanced parenthesised subquery; on unbalanced input
		// this runs to the end of the text, mirroring the original behavior.
		nesting := 0
		for ; pos < len(sqlText); pos++ {
			if sqlText[pos] == '(' {
				nesting++
			} else if sqlText[pos] == ')' {
				nesting--
				if nesting == 0 {
					pos++
					break
				}
			}
		}
		for pos < len(sqlText) && isSQLSpace(sqlText[pos]) {
			pos++
		}
		// Optional alias directly after the closing parenthesis.
		for pos < len(sqlText) && isAliasChar(sqlText[pos]) {
			pos++
		}
		return strings.TrimSpace(sqlText[start:pos])
	}
	start := pos
	for pos < len(sqlText) && isTableRefChar(sqlText[pos]) {
		pos++
	}
	return strings.TrimSpace(sqlText[start:pos])
}

// isSQLSpace reports whether ch is whitespace for the purpose of this scanner.
func isSQLSpace(ch byte) bool {
	return ch == ' ' || ch == '\n' || ch == '\t' || ch == '\r'
}

// isAliasChar reports whether ch may appear in a subquery alias.
func isAliasChar(ch byte) bool {
	return ch == '_' || ch == '.' ||
		(ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9')
}

// isTableRefChar reports whether ch may appear in a bare table reference,
// including template-placeholder and path-like names.
func isTableRefChar(ch byte) bool {
	return ch == '_' || ch == '.' || ch == '/' || ch == '{' || ch == '}' || ch == '$' ||
		(ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9')
}
+ } + if strings.EqualFold(current, desired) { + return + } + suspicious := map[string]bool{ + "and": true, "or": true, "status": true, "value": true, "watching": true, + } + if !suspicious[strings.ToLower(current)] { + return + } + if result.ViewsByName != nil { + delete(result.ViewsByName, root.Name) + } else { + result.ViewsByName = map[string]*plan.View{} + } + root.Name = desired + root.Path = desired + root.Holder = desired + result.ViewsByName[root.Name] = root +} diff --git a/repository/shape/compile/enrich_test.go b/repository/shape/compile/enrich_test.go new file mode 100644 index 000000000..1c532d398 --- /dev/null +++ b/repository/shape/compile/enrich_test.go @@ -0,0 +1,172 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func TestApplySourceParityEnrichment_RuleConnectorAndSQLURI(t *testing.T) { + source := &shape.Source{ + Path: "/repo/dql/platform/timezone/timezone.dql", + DQL: `/* {"Connector":"ci_ads"} */ SELECT * FROM CI_TIME_ZONE t`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "timezone", Table: "timezone", SQL: "SELECT * FROM CI_TIME_ZONE t"}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "ci_ads", result.Views[0].Connector) + require.Equal(t, "timezone/timezone.sql", result.Views[0].SQLURI) + require.Equal(t, "CI_TIME_ZONE", result.Views[0].Table) +} + +func TestApplySourceParityEnrichment_InferTableFromSubquery(t *testing.T) { + source := &shape.Source{ + Path: "/repo/dql/platform/advertiser/advertiser.dql", + DQL: `SELECT x.* FROM (SELECT a.* FROM CI_ADVERTISER a) x`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "advertiser", Table: "advertiser", SQL: `SELECT x.* FROM (SELECT a.* FROM CI_ADVERTISER a) x`}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "CI_ADVERTISER", 
result.Views[0].Table) + require.Equal(t, "advertiser/advertiser.sql", result.Views[0].SQLURI) +} + +func TestApplySourceParityEnrichment_InferTableFromEmbed(t *testing.T) { + tempDir := t.TempDir() + dqlDir := filepath.Join(tempDir, "dql", "platform", "timezone") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + embedded := filepath.Join(dqlDir, "timezone.sql") + require.NoError(t, os.WriteFile(embedded, []byte(`SELECT tz.ID FROM CI_TIME_ZONE tz`), 0o644)) + source := &shape.Source{ + Path: filepath.Join(dqlDir, "timezone.dql"), + DQL: `SELECT timezone.* FROM (${embed: timezone.sql}) timezone`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "timezone", Table: "timezone", SQL: `SELECT timezone.* FROM (${embed: timezone.sql}) timezone`}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "CI_TIME_ZONE", result.Views[0].Table) + require.Equal(t, "timezone/timezone.sql", result.Views[0].SQLURI) +} + +func TestTopLevelFromExpr_IgnoresNestedFrom(t *testing.T) { + sqlText := `SELECT a.*, EXISTS(SELECT 1 FROM CI_ENTITY_WATCHLIST w WHERE w.ENTITY_ID = a.ID) AS watching FROM (SELECT x.* FROM CI_ADVERTISER x) a` + require.Equal(t, "(SELECT x.* FROM CI_ADVERTISER x) a", topLevelFromExpr(sqlText)) +} + +func TestInferConnector(t *testing.T) { + require.Equal(t, "system", inferConnector(&plan.View{Table: "session"}, &shape.Source{Path: "/repo/dql/system/session/session.dql"})) + require.Equal(t, "ci_ads", inferConnector(&plan.View{Table: "CI_ADVERTISER"}, &shape.Source{Path: "/repo/dql/platform/advertiser/advertiser.dql"})) + require.Equal(t, "sitemgmt", inferConnector(&plan.View{Table: "SITE_MAP"}, &shape.Source{Path: "/repo/dql/ui/agency/detail/campaign.dql"})) +} + +func TestExtractSummarySQL(t *testing.T) { + sqlText := `SELECT b.* FROM CI_BROWSER b +JOIN ( + SELECT COUNT(1) AS CNT + FROM ($View.browser.SQL) t +) summary ON 1=1` + require.Contains(t, extractSummarySQL(sqlText), "COUNT(1)") +} + +func 
TestInferTableFromSQL_PreservesTemplateQualifiedTable(t *testing.T) { + sqlText := `SELECT SITE_ID FROM ${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH slm` + require.Equal(t, "${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH", inferTableFromSQL(sqlText, nil)) +} + +func TestShouldInferTable_NormalizedTemplatePlaceholderTable(t *testing.T) { + require.True(t, shouldInferTable(&plan.View{Name: "match", Table: "1.1.SITE_LIST_MATCH"})) + require.False(t, shouldInferTable(&plan.View{Name: "match", Table: "SITE_LIST_MATCH"})) +} + +func TestInferTableFromSQL_PathLikeTable(t *testing.T) { + sqlText := `SELECT user_id FROM session/attributes WHERE user_id = 1` + require.Equal(t, "session/attributes", inferTableFromSQL(sqlText, nil)) +} + +func TestApplySourceParityEnrichment_InferTableFromSiblingSQLOnPlaceholderTable(t *testing.T) { + tempDir := t.TempDir() + dqlDir := filepath.Join(tempDir, "dql", "platform", "sitelist") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dqlDir, "match.sql"), []byte(`SELECT SITE_ID FROM ${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH slm`), 0o644)) + source := &shape.Source{ + Path: filepath.Join(dqlDir, "match.dql"), + DQL: `SELECT 1`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "match", Table: "1.1.SITE_LIST_MATCH"}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH", result.Views[0].Table) +} + +func TestExtractJoinSubqueryBodies(t *testing.T) { + sqlText := `SELECT sl.* FROM SITE_LIST sl +JOIN ( + SELECT SITE_ID, SITE_LIST_ID FROM ${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH +) match ON match.SITE_LIST_ID = sl.ID +JOIN ( + ${embed: match_rules.sql} + ${predicate.Builder().CombineOr($predicate.FilterGroup(1, "AND")).Build("WHERE")} +) matchRules ON matchRules.SITE_LIST_ID = sl.ID` + bodies := extractJoinSubqueryBodies(sqlText) + require.Contains(t, 
bodies, "match") + require.Contains(t, bodies["match"], "SITE_LIST_MATCH") + require.Contains(t, bodies, "matchRules") + require.Contains(t, bodies["matchRules"], "${embed: match_rules.sql}") +} + +func TestApplySourceParityEnrichment_Metadata(t *testing.T) { + source := &shape.Source{ + Path: "/repo/dql/platform/tvaffiliatestation/tvaffiliatestation.dql", + DQL: `/* {"Name":"TvAffiliateStation"} */ +SELECT use_connector(tvAffiliateStation, 'ci_ads'), + allow_nulls(tvAffiliateStation), + set_limit(tvAffiliateStation, 0) +FROM CI_TV_AFFILIATE_STATION tvAffiliateStation +JOIN ( + SELECT COUNT(1) AS CNT FROM ($View.tvAffiliateStation.SQL) t +) summary ON 1=1`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "tvAffiliateStation", Table: "CI_TV_AFFILIATE_STATION", SQL: "SELECT * FROM CI_TV_AFFILIATE_STATION tvAffiliateStation"}, + }, + } + hints := extractViewHints(source.DQL) + applyViewHints(result, hints) + applySourceParityEnrichment(result, source) + + require.Len(t, result.Views, 1) + actual := result.Views[0] + require.NotNil(t, actual.AllowNulls) + require.True(t, *actual.AllowNulls) + require.NotNil(t, actual.SelectorNoLimit) + require.True(t, *actual.SelectorNoLimit) + require.Equal(t, "tv", actual.SelectorNamespace) + require.Equal(t, "platform/tvaffiliatestation", actual.Module) + require.Equal(t, "*TvAffiliateStationView", actual.SchemaType) + require.NotEmpty(t, actual.Summary) +} diff --git a/repository/shape/compile/enrich_text.go b/repository/shape/compile/enrich_text.go new file mode 100644 index 000000000..dbf8f2b99 --- /dev/null +++ b/repository/shape/compile/enrich_text.go @@ -0,0 +1,212 @@ +package compile + +// enrich_text.go — low-level text/SQL scanning primitives used by the +// enrichment phase (enrich.go and enrich_table.go). + +import "strings" + +// findSummaryJoinBody locates the body of a JOIN(...) SUMMARY ON 1=1 clause. 
+func findSummaryJoinBody(input string) (string, bool) { + lower := strings.ToLower(input) + for i := 0; i < len(input); i++ { + if !hasCompileWordAt(lower, i, "join") { + continue + } + pos := skipCompileSpaces(input, i+len("join")) + if pos >= len(input) || input[pos] != '(' { + continue + } + body, end, ok := readCompileParenBody(input, pos) + if !ok { + continue + } + rest := strings.ToLower(input[end+1:]) + rest = strings.Join(strings.Fields(rest), " ") + if strings.HasPrefix(rest, "summary on 1=1") || strings.HasPrefix(rest, "summary on 1 = 1") { + return body, true + } + } + return "", false +} + +// extractLeadingRuleHeaderJSON returns the JSON body of a leading /* {...} */ comment. +func extractLeadingRuleHeaderJSON(input string) (string, bool) { + index := skipCompileSpaces(input, 0) + if index+2 > len(input) || input[index:index+2] != "/*" { + return "", false + } + end := strings.Index(input[index+2:], "*/") + if end < 0 { + return "", false + } + body := strings.TrimSpace(input[index+2 : index+2+end]) + if body == "" || body[0] != '{' || body[len(body)-1] != '}' { + return "", false + } + return body, true +} + +// findFirstEmbedRef returns the path after "embed:" in the first ${embed:…} +// template expression found in input. +func findFirstEmbedRef(input string) (string, bool) { + for i := 0; i < len(input); i++ { + if input[i] != '$' || i+1 >= len(input) || input[i+1] != '{' { + continue + } + body, end, ok := readCompileTemplateExpr(input, i+1) + if !ok { + continue + } + _ = end + trimmed := strings.TrimSpace(body) + if len(trimmed) < len("embed:") || !strings.HasPrefix(strings.ToLower(trimmed), "embed:") { + continue + } + ref := strings.TrimSpace(trimmed[len("embed:"):]) + if ref == "" { + continue + } + return ref, true + } + return "", false +} + +// joinSubquery holds the body and alias of a JOIN(...) AS alias clause. 
+type joinSubquery struct { + body string + alias string +} + +// scanJoinSubqueries collects all JOIN(body) alias pairs from input. +func scanJoinSubqueries(input string) []joinSubquery { + result := make([]joinSubquery, 0) + lower := strings.ToLower(input) + for i := 0; i < len(input); i++ { + if !hasCompileWordAt(lower, i, "join") { + continue + } + pos := skipCompileSpaces(input, i+len("join")) + if pos >= len(input) || input[pos] != '(' { + continue + } + body, end, ok := readCompileParenBody(input, pos) + if !ok { + continue + } + pos = skipCompileSpaces(input, end+1) + if hasCompileWordAt(lower, pos, "as") { + pos = skipCompileSpaces(input, pos+len("as")) + } + aliasStart := pos + if aliasStart >= len(input) || !isCompileWordStart(input[aliasStart]) { + i = end + continue + } + pos++ + for pos < len(input) && isCompileWordPart(input[pos]) { + pos++ + } + alias := strings.TrimSpace(input[aliasStart:pos]) + if alias != "" { + result = append(result, joinSubquery{body: body, alias: alias}) + } + i = end + } + return result +} + +// parseJoinEmbedRef returns the embed path from a body of the form ${embed:path}. 
+func parseJoinEmbedRef(body string) (string, bool) { + trimmed := strings.TrimSpace(body) + if !strings.HasPrefix(trimmed, "${") || !strings.HasSuffix(trimmed, "}") { + return "", false + } + inner := strings.TrimSpace(trimmed[2 : len(trimmed)-1]) + if len(inner) < len("embed:") || !strings.HasPrefix(strings.ToLower(inner), "embed:") { + return "", false + } + ref := strings.TrimSpace(inner[len("embed:"):]) + return ref, ref != "" +} + +func readCompileTemplateExpr(input string, openBrace int) (string, int, bool) { + if openBrace <= 0 || openBrace >= len(input) || input[openBrace] != '{' || input[openBrace-1] != '$' { + return "", -1, false + } + for i := openBrace + 1; i < len(input); i++ { + if input[i] == '}' { + return input[openBrace+1 : i], i, true + } + } + return "", -1, false +} + +func readCompileParenBody(input string, openParen int) (string, int, bool) { + depth := 0 + quote := byte(0) + for i := openParen; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + depth-- + if depth == 0 { + return input[openParen+1 : i], i, true + } + } + } + return "", -1, false +} + +func hasCompileWordAt(lower string, pos int, word string) bool { + if pos < 0 || pos+len(word) > len(lower) { + return false + } + if lower[pos:pos+len(word)] != word { + return false + } + if pos > 0 && isCompileWordPart(lower[pos-1]) { + return false + } + next := pos + len(word) + if next < len(lower) && isCompileWordPart(lower[next]) { + return false + } + return true +} + +func skipCompileSpaces(input string, index int) int { + for index < len(input) { + switch input[index] { + case ' ', '\t', '\n', '\r': + index++ + default: + return index + } + } + return index +} + +func isCompileWordStart(ch byte) bool { + return ch == '_' || (ch >= 'a' && ch <= 'z') || 
(ch >= 'A' && ch <= 'Z') +} + +func isCompileWordPart(ch byte) bool { + return isCompileWordStart(ch) || (ch >= '0' && ch <= '9') +} diff --git a/repository/shape/compile/hints.go b/repository/shape/compile/hints.go new file mode 100644 index 000000000..a16b855d2 --- /dev/null +++ b/repository/shape/compile/hints.go @@ -0,0 +1,334 @@ +package compile + +import ( + "reflect" + "strconv" + "strings" + + "github.com/viant/datly/repository/shape/plan" +) + +type viewHint struct { + Connector string + AllowNulls *bool + NoLimit *bool +} + +func extractViewHints(dql string) map[string]viewHint { + result := map[string]viewHint{} + for _, call := range scanHintCalls(dql) { + switch call.name { + case "use_connector": + if len(call.args) != 2 { + continue + } + alias := strings.TrimSpace(call.args[0]) + connector := unquote(strings.TrimSpace(call.args[1])) + if !isIdentifier(alias) || !isIdentifier(connector) { + continue + } + hint := result[alias] + hint.Connector = connector + result[alias] = hint + case "allow_nulls": + if len(call.args) != 1 { + continue + } + alias := strings.TrimSpace(call.args[0]) + if !isIdentifier(alias) { + continue + } + hint := result[alias] + value := true + hint.AllowNulls = &value + result[alias] = hint + case "set_limit": + if len(call.args) != 2 { + continue + } + alias := strings.TrimSpace(call.args[0]) + limitRaw := strings.TrimSpace(call.args[1]) + if !isIdentifier(alias) || limitRaw == "" { + continue + } + limit, err := strconv.Atoi(limitRaw) + if err != nil { + continue + } + hint := result[alias] + noLimit := limit == 0 + hint.NoLimit = &noLimit + result[alias] = hint + } + } + return result +} + +type hintCall struct { + name string + args []string +} + +func scanHintCalls(input string) []hintCall { + result := make([]hintCall, 0) + for i := 0; i < len(input); { + if !isIdentifierStart(input[i]) { + i++ + continue + } + start := i + i++ + for i < len(input) && isIdentifierPart(input[i]) { + i++ + } + name := 
strings.ToLower(input[start:i]) + if name != "use_connector" && name != "allow_nulls" && name != "set_limit" { + continue + } + j := skipSpaces(input, i) + if j >= len(input) || input[j] != '(' { + continue + } + body, end, ok := readCallBody(input, j) + if !ok { + continue + } + result = append(result, hintCall{name: name, args: splitCallArgs(body)}) + i = end + 1 + } + return result +} + +func readCallBody(input string, openParen int) (string, int, bool) { + depth := 0 + quote := byte(0) + for i := openParen; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + depth-- + if depth == 0 { + return input[openParen+1 : i], i, true + } + } + } + return "", -1, false +} + +func splitCallArgs(input string) []string { + args := make([]string, 0) + current := strings.Builder{} + depth := 0 + quote := byte(0) + for i := 0; i < len(input); i++ { + ch := input[i] + if quote != 0 { + current.WriteByte(ch) + if ch == '\\' && i+1 < len(input) { + i++ + current.WriteByte(input[i]) + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + current.WriteByte(ch) + continue + } + if ch == '(' { + depth++ + current.WriteByte(ch) + continue + } + if ch == ')' { + if depth > 0 { + depth-- + } + current.WriteByte(ch) + continue + } + if ch == ',' && depth == 0 { + args = append(args, strings.TrimSpace(current.String())) + current.Reset() + continue + } + current.WriteByte(ch) + } + if value := strings.TrimSpace(current.String()); value != "" { + args = append(args, value) + } + return args +} + +func isIdentifierStart(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' +} + +func isIdentifierPart(ch byte) bool { + return isIdentifierStart(ch) || (ch >= '0' && ch <= 
'9') +} + +func isIdentifier(value string) bool { + value = strings.TrimSpace(value) + if value == "" || !isIdentifierStart(value[0]) { + return false + } + for i := 1; i < len(value); i++ { + if !isIdentifierPart(value[i]) { + return false + } + } + return true +} + +func unquote(value string) string { + if len(value) >= 2 { + first := value[0] + last := value[len(value)-1] + if (first == '\'' && last == '\'') || (first == '"' && last == '"') { + return value[1 : len(value)-1] + } + } + return value +} + +func skipSpaces(input string, index int) int { + for index < len(input) { + switch input[index] { + case ' ', '\t', '\n', '\r': + index++ + default: + return index + } + } + return index +} + +func appendRelationViews(result *plan.Result, root *plan.View, hints map[string]viewHint) { + if result == nil || root == nil || len(root.Relations) == 0 { + return + } + for _, relation := range root.Relations { + if relation == nil { + continue + } + name := strings.TrimSpace(relation.Ref) + if name == "" { + name = strings.TrimSpace(relation.Name) + } + if name == "" { + continue + } + if len(relation.On) == 0 { + continue + } + if _, exists := result.ViewsByName[name]; exists { + continue + } + table := strings.TrimSpace(relation.Table) + if table == "" { + table = name + } + table = normalizeRelationTable(table) + view := &plan.View{ + Path: name, + Holder: name, + Name: name, + Table: table, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + } + result.Views = append(result.Views, view) + result.ViewsByName[name] = view + } +} + +func applyViewHints(result *plan.Result, hints map[string]viewHint) { + if result == nil || len(result.Views) == 0 { + return + } + if len(hints) == 0 { + return + } + for _, item := range result.Views { + if item == nil { + continue + } + for _, key := range []string{item.Name, item.Holder} { + key = strings.TrimSpace(key) + if key == "" { + continue + } 
+ hint, ok := hints[key] + if !ok { + continue + } + if item.Connector == "" && hint.Connector != "" { + item.Connector = hint.Connector + } + if item.AllowNulls == nil && hint.AllowNulls != nil { + value := *hint.AllowNulls + item.AllowNulls = &value + } + if item.SelectorNoLimit == nil && hint.NoLimit != nil { + value := *hint.NoLimit + item.SelectorNoLimit = &value + } + } + } +} + +func normalizeRelationTable(table string) string { + table = strings.TrimSpace(table) + if table == "" { + return table + } + lower := strings.ToLower(table) + fromIdx := strings.Index(lower, " from ") + if fromIdx == -1 { + return table + } + tail := strings.TrimSpace(table[fromIdx+6:]) + if tail == "" { + return table + } + stop := len(tail) + for i := 0; i < len(tail); i++ { + switch tail[i] { + case ' ', '\t', '\n', '\r', ')': + stop = i + i = len(tail) + } + } + if stop == 0 { + return table + } + normalized := strings.TrimSpace(tail[:stop]) + normalized = strings.Trim(normalized, "`\"") + if normalized == "" { + return table + } + return normalized +} diff --git a/repository/shape/compile/hints_test.go b/repository/shape/compile/hints_test.go new file mode 100644 index 000000000..c40470e0f --- /dev/null +++ b/repository/shape/compile/hints_test.go @@ -0,0 +1,54 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/plan" +) + +func TestExtractViewHints_WithQuotedConnector(t *testing.T) { + dql := "SELECT use_connector(match, 'bq_sitemgmt_match'), use_connector(site, \"ci_ads\"), allow_nulls(match), set_limit(match, 0)" + hints := extractViewHints(dql) + require.Len(t, hints, 2) + assert.Equal(t, "bq_sitemgmt_match", hints["match"].Connector) + assert.Equal(t, "ci_ads", hints["site"].Connector) + require.NotNil(t, hints["match"].AllowNulls) + assert.True(t, *hints["match"].AllowNulls) + require.NotNil(t, hints["match"].NoLimit) + assert.True(t, 
*hints["match"].NoLimit) +} + +func TestExtractViewHints_MixedCaseAndUnquotedConnector(t *testing.T) { + dql := "SELECT USE_CONNECTOR(match, ci_ads), Allow_Nulls(match), set_limit(match, -1)" + hints := extractViewHints(dql) + require.Contains(t, hints, "match") + assert.Equal(t, "ci_ads", hints["match"].Connector) + require.NotNil(t, hints["match"].AllowNulls) + assert.True(t, *hints["match"].AllowNulls) + require.NotNil(t, hints["match"].NoLimit) + assert.False(t, *hints["match"].NoLimit) +} + +func TestApplyViewHints_Metadata(t *testing.T) { + trueValue := true + result := &plan.Result{ + Views: []*plan.View{ + {Name: "match", Table: "MATCH"}, + }, + } + applyViewHints(result, map[string]viewHint{ + "match": { + Connector: "ci_ads", + AllowNulls: &trueValue, + NoLimit: &trueValue, + }, + }) + require.Len(t, result.Views, 1) + assert.Equal(t, "ci_ads", result.Views[0].Connector) + require.NotNil(t, result.Views[0].AllowNulls) + assert.True(t, *result.Views[0].AllowNulls) + require.NotNil(t, result.Views[0].SelectorNoLimit) + assert.True(t, *result.Views[0].SelectorNoLimit) +} diff --git a/repository/shape/compile/legacy_adapter.go b/repository/shape/compile/legacy_adapter.go new file mode 100644 index 000000000..c91409b01 --- /dev/null +++ b/repository/shape/compile/legacy_adapter.go @@ -0,0 +1,319 @@ +package compile + +import ( + "os" + "path/filepath" + "reflect" + "sort" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "gopkg.in/yaml.v3" +) + +func resolveGeneratedCompanionDQL(source *shape.Source) string { + if source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + settings := extractRuleSettings(source, nil) + typeExpr := strings.TrimSpace(settings.Type) + if typeExpr == "" { + return "" + } + typeExpr = strings.TrimSuffix(typeExpr, ".Handler") + typeExpr = strings.Trim(typeExpr, `"'`) + if typeExpr == "" { + return "" + } + dir := filepath.Dir(source.Path) + baseTypePath := 
filepath.FromSlash(typeExpr) + stem := filepath.Base(baseTypePath) + candidates := []string{ + filepath.Join(dir, "gen", baseTypePath+".dql"), + filepath.Join(dir, "gen", baseTypePath+".sql"), + filepath.Join(dir, "gen", stem+".dql"), + filepath.Join(dir, "gen", stem+".sql"), + } + for _, candidate := range candidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + content := strings.TrimSpace(string(data)) + if content != "" { + return content + } + } + return "" +} + +func resolveLegacyRouteViews(source *shape.Source) []*plan.View { + return resolveLegacyRouteViewsWithLayout(source, defaultCompilePathLayout()) +} + +func resolveLegacyRouteViewsWithLayout(source *shape.Source, layout compilePathLayout) []*plan.View { + if source == nil || strings.TrimSpace(source.Path) == "" { + return nil + } + platformRoot, relativeDir, stem, ok := platformPathParts(source.Path, layout) + if !ok { + return nil + } + settings := extractRuleSettings(source, nil) + typeExpr := strings.TrimSpace(settings.Type) + typeExpr = strings.Trim(typeExpr, `"'`) + typeExpr = strings.TrimSuffix(typeExpr, ".Handler") + typeStem := "" + if typeExpr != "" { + typeStem = filepath.Base(filepath.FromSlash(typeExpr)) + } + routesRoot := joinRelativePath(platformRoot, layout.routesRelative) + routesBase := filepath.Join(routesRoot, filepath.FromSlash(relativeDir)) + legacyMeta := []legacyViewMeta(nil) + for _, candidateYAML := range legacyRouteYAMLCandidates(routesBase, stem, typeStem) { + legacyMeta = loadLegacyRouteViewMeta(candidateYAML) + if len(legacyMeta) > 0 { + break + } + } + searchDirs := []string{ + filepath.Join(routesBase, typeStem, stem), + filepath.Join(routesBase, typeStem), + filepath.Join(routesBase, stem, stem), + filepath.Join(routesBase, stem), + routesBase, + } + var sqlFiles []string + for _, dir := range searchDirs { + entries, err := os.ReadDir(dir) + if err != nil { + continue + } + for _, entry := range entries { + if entry.IsDir() || 
!strings.HasSuffix(strings.ToLower(entry.Name()), ".sql") { + continue + } + sqlFiles = append(sqlFiles, filepath.Join(dir, entry.Name())) + } + if len(sqlFiles) > 0 { + break + } + } + if len(sqlFiles) == 0 { + return nil + } + sort.Strings(sqlFiles) + result := make([]*plan.View, 0, len(sqlFiles)) + rootIndex := -1 + for _, sqlFile := range sqlFiles { + name := strings.TrimSuffix(filepath.Base(sqlFile), filepath.Ext(sqlFile)) + if name == "" { + continue + } + data, err := os.ReadFile(sqlFile) + if err != nil { + continue + } + sqlText := string(data) + table := "" + if name != stem { + table = inferTableFromSQL(sqlText, source) + } + connector := strings.TrimSpace(settings.Connector) + if connector == "" { + connector = strings.TrimSpace(source.Connector) + } + if connector == "" { + connector = inferConnector(&plan.View{Table: table}, source) + } + viewItem := &plan.View{ + Path: name, + Holder: name, + Name: name, + Table: table, + SQL: sqlText, + SQLURI: filepath.ToSlash(filepath.Join(stem, name+".sql")), + Connector: connector, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + } + if meta, ok := lookupLegacyViewMeta(legacyMeta, name); ok { + if strings.TrimSpace(meta.Table) != "" { + viewItem.Table = strings.TrimSpace(meta.Table) + } + if strings.TrimSpace(meta.Connector) != "" { + viewItem.Connector = strings.TrimSpace(meta.Connector) + } + if strings.TrimSpace(meta.SQLURI) != "" { + viewItem.SQLURI = strings.TrimSpace(meta.SQLURI) + } + } + if name == stem { + rootIndex = len(result) + } + result = append(result, viewItem) + } + if len(result) == 0 { + return nil + } + if rootIndex > 0 { + root := result[rootIndex] + copy(result[1:rootIndex+1], result[0:rootIndex]) + result[0] = root + } + if result[0].Name != stem { + rootConnector := result[0].Connector + result = append([]*plan.View{{ + Path: stem, + Holder: stem, + Name: stem, + Table: "", + SQLURI: 
filepath.ToSlash(filepath.Join(stem, stem+".sql")), + Connector: rootConnector, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }}, result...) + } + result[0].Table = "" + result[0].Name = stem + result[0].Holder = stem + result[0].Path = stem + if meta, ok := lookupLegacyViewMeta(legacyMeta, stem); ok { + if strings.TrimSpace(meta.Table) != "" { + result[0].Table = strings.TrimSpace(meta.Table) + } + if strings.TrimSpace(meta.Connector) != "" { + result[0].Connector = strings.TrimSpace(meta.Connector) + } + if strings.TrimSpace(meta.SQLURI) != "" { + result[0].SQLURI = strings.TrimSpace(meta.SQLURI) + } + } + if result[0].SQLURI == "" { + result[0].SQLURI = filepath.ToSlash(filepath.Join(stem, stem+".sql")) + } + return result +} + +type legacyViewMeta struct { + Name string + Table string + Connector string + SQLURI string +} + +func loadLegacyRouteViewMeta(yamlPath string) []legacyViewMeta { + data, err := os.ReadFile(yamlPath) + if err != nil { + return nil + } + var payload struct { + Resource struct { + Views []struct { + Name string `yaml:"Name"` + Table string `yaml:"Table"` + Connector struct { + Ref string `yaml:"Ref"` + } `yaml:"Connector"` + Template struct { + SourceURL string `yaml:"SourceURL"` + } `yaml:"Template"` + } `yaml:"Views"` + } `yaml:"Resource"` + } + if err = yaml.Unmarshal(data, &payload); err != nil { + return nil + } + result := make([]legacyViewMeta, 0, len(payload.Resource.Views)) + for _, item := range payload.Resource.Views { + result = append(result, legacyViewMeta{ + Name: strings.TrimSpace(item.Name), + Table: strings.TrimSpace(item.Table), + Connector: strings.TrimSpace(item.Connector.Ref), + SQLURI: strings.TrimSpace(item.Template.SourceURL), + }) + } + return result +} + +func lookupLegacyViewMeta(items []legacyViewMeta, name string) (legacyViewMeta, bool) { + name = strings.TrimSpace(name) + if name == "" { + return legacyViewMeta{}, 
false + } + for _, item := range items { + if strings.EqualFold(strings.TrimSpace(item.Name), name) { + return item, true + } + } + return legacyViewMeta{}, false +} + +func legacyRouteYAMLCandidates(routesBase, stem, typeStem string) []string { + stemFileVariants := routeStemAlternatives(stem) + stemDirVariants := routeStemAlternatives(stem) + typeVariants := routeStemAlternatives(typeStem) + var result []string + seen := map[string]bool{} + appendCandidate := func(path string) { + path = filepath.Clean(path) + if path == "." || path == "" || seen[path] { + return + } + seen[path] = true + result = append(result, path) + } + for _, fileStem := range stemFileVariants { + appendCandidate(filepath.Join(routesBase, fileStem+".yaml")) + for _, dirStem := range stemDirVariants { + appendCandidate(filepath.Join(routesBase, dirStem, fileStem+".yaml")) + } + for _, itemTypeStem := range typeVariants { + if strings.TrimSpace(itemTypeStem) == "" { + continue + } + appendCandidate(filepath.Join(routesBase, itemTypeStem, fileStem+".yaml")) + } + } + return result +} + +func routeStemAlternatives(value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return nil + } + alts := []string{value} + dashed := strings.ReplaceAll(value, "_", "-") + if dashed != value { + alts = append(alts, dashed) + } + return alts +} + +func platformPathParts(sourcePath string, layout compilePathLayout) (platformRoot, relativeDir, stem string, ok bool) { + sourcePath = filepath.Clean(strings.TrimSpace(sourcePath)) + if sourcePath == "" { + return "", "", "", false + } + normalized := filepath.ToSlash(sourcePath) + marker := layout.dqlMarker + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + idx := strings.Index(normalized, marker) + if idx == -1 { + return "", "", "", false + } + platformRoot = sourcePath[:idx] + relative := strings.TrimPrefix(normalized[idx+len(marker):], "/") + relativeDir = filepath.Dir(relative) + stem = 
strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + if strings.TrimSpace(stem) == "" { + return "", "", "", false + } + return platformRoot, relativeDir, stem, true +} diff --git a/repository/shape/compile/pathlayout.go b/repository/shape/compile/pathlayout.go new file mode 100644 index 000000000..a1bc081e2 --- /dev/null +++ b/repository/shape/compile/pathlayout.go @@ -0,0 +1,67 @@ +package compile + +import ( + "path/filepath" + "strings" + + "github.com/viant/datly/repository/shape" +) + +type compilePathLayout struct { + dqlMarker string + routesRelative string +} + +func defaultCompilePathLayout() compilePathLayout { + return compilePathLayout{ + dqlMarker: "/dql/", + routesRelative: "repo/dev/Datly/routes", + } +} + +func newCompilePathLayout(opts *shape.CompileOptions) compilePathLayout { + ret := defaultCompilePathLayout() + if opts == nil { + return ret + } + if marker := normalizeDQLMarker(opts.DQLPathMarker); marker != "" { + ret.dqlMarker = marker + } + if rel := normalizeRoutesRelative(opts.RoutesRelativePath); rel != "" { + ret.routesRelative = rel + } + return ret +} + +func normalizeDQLMarker(input string) string { + input = strings.TrimSpace(strings.ReplaceAll(input, "\\", "/")) + if input == "" { + return "" + } + input = strings.Trim(input, "/") + if input == "" { + return "" + } + return "/" + input + "/" +} + +func normalizeRoutesRelative(input string) string { + input = strings.TrimSpace(strings.ReplaceAll(input, "\\", "/")) + input = strings.Trim(input, "/") + if input == "" { + return "" + } + return input +} + +func joinRelativePath(base string, rel string) string { + rel = normalizeRoutesRelative(rel) + if rel == "" { + return base + } + parts := strings.Split(rel, "/") + args := make([]string, 0, len(parts)+1) + args = append(args, base) + args = append(args, parts...) + return filepath.Join(args...) 
+} diff --git a/repository/shape/compile/pipeline/diag.go b/repository/shape/compile/pipeline/diag.go new file mode 100644 index 000000000..45bf78568 --- /dev/null +++ b/repository/shape/compile/pipeline/diag.go @@ -0,0 +1,47 @@ +package pipeline + +import ( + "unicode/utf8" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func StatementSpan(sqlText string, stmt *dqlstmt.Statement) dqlshape.Span { + if stmt == nil { + return pointSpan(sqlText, 0) + } + return pointSpan(sqlText, stmt.Start) +} + +func pointSpan(text string, offset int) dqlshape.Span { + start := positionAt(text, offset) + end := start + return dqlshape.Span{Start: start, End: end} +} + +func positionAt(text string, offset int) dqlshape.Position { + if offset < 0 { + offset = 0 + } + if offset > len(text) { + offset = len(text) + } + line := 1 + char := 1 + index := 0 + for index < offset { + r, width := utf8.DecodeRuneInString(text[index:]) + if width <= 0 { + break + } + index += width + if r == '\n' { + line++ + char = 1 + } else { + char++ + } + } + return dqlshape.Position{Offset: offset, Line: line, Char: char} +} diff --git a/repository/shape/compile/pipeline/exec.go b/repository/shape/compile/pipeline/exec.go new file mode 100644 index 000000000..6bf144888 --- /dev/null +++ b/repository/shape/compile/pipeline/exec.go @@ -0,0 +1,109 @@ +package pipeline + +import ( + "reflect" + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" +) + +func BuildExec(sourceName, sqlText string, statements dqlstmt.Statements) (*plan.View, []*dqlshape.Diagnostic) { + name := SanitizeName(sourceName) + if name == "" { + name = "DQLView" + } + tables := statements.DMLTables(sqlText) + table := name + 
if len(tables) > 0 { + table = tables[0] + } + fieldType := reflect.TypeOf([]map[string]interface{}{}) + elementType := reflect.TypeOf(map[string]interface{}{}) + view := &plan.View{ + Path: name, + Holder: name, + Name: name, + Mode: "SQLExec", + Table: table, + SQL: sqlText, + Cardinality: "many", + FieldType: fieldType, + ElementType: elementType, + } + return view, ValidateExecStatements(sqlText, statements) +} + +func ValidateExecStatements(sqlText string, statements dqlstmt.Statements) []*dqlshape.Diagnostic { + var result []*dqlshape.Diagnostic + for _, stmt := range statements { + if stmt == nil || !stmt.IsExec { + continue + } + body := strings.TrimSpace(sqlText[stmt.Start:stmt.End]) + if body == "" { + continue + } + lower := strings.ToLower(body) + span := StatementSpan(sqlText, stmt) + switch { + case stmt.Kind == dqlstmt.KindService: + if firstQuoted(body) == "" { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLServiceArg, + Severity: dqlshape.SeverityError, + Message: "service DML call is missing quoted table argument", + Hint: "use $sql.Insert(\"TABLE\", ...) 
or $sql.Update(\"TABLE\", ...)", + Span: span, + }) + } + case strings.HasPrefix(lower, "insert"): + if _, err := sqlparser.ParseInsert(body); err != nil { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLInsert, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "fix INSERT statement syntax", + Span: span, + }) + } + case strings.HasPrefix(lower, "update"): + if _, err := sqlparser.ParseUpdate(body); err != nil { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLUpdate, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "fix UPDATE statement syntax", + Span: span, + }) + } + case strings.HasPrefix(lower, "delete"): + if _, err := sqlparser.ParseDelete(body); err != nil { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLDelete, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "fix DELETE statement syntax", + Span: span, + }) + } + } + } + return result +} + +func firstQuoted(input string) string { + index := strings.Index(input, `"`) + if index == -1 { + return "" + } + tail := input[index+1:] + end := strings.Index(tail, `"`) + if end == -1 { + return "" + } + return strings.TrimSpace(tail[:end]) +} diff --git a/repository/shape/compile/pipeline/exec_test.go b/repository/shape/compile/pipeline/exec_test.go new file mode 100644 index 000000000..8c70fc74d --- /dev/null +++ b/repository/shape/compile/pipeline/exec_test.go @@ -0,0 +1,26 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func TestBuildExec(t *testing.T) { + sqlText := "INSERT INTO ORDERS(id) VALUES (1)" + view, diags := BuildExec("orders_exec", sqlText, dqlstmt.New(sqlText)) + require.NotNil(t, view) + assert.Equal(t, 
"ORDERS", view.Table) + assert.Equal(t, "many", view.Cardinality) + assert.Empty(t, diags) +} + +func TestValidateExecStatements_ServiceArg(t *testing.T) { + sqlText := "$sql.Insert($rec)" + diags := ValidateExecStatements(sqlText, dqlstmt.New(sqlText)) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeDMLServiceArg, diags[0].Code) +} diff --git a/repository/shape/compile/pipeline/infer.go b/repository/shape/compile/pipeline/infer.go new file mode 100644 index 000000000..7ad212431 --- /dev/null +++ b/repository/shape/compile/pipeline/infer.go @@ -0,0 +1,247 @@ +package pipeline + +import ( + "fmt" + "reflect" + "strings" + + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/query" +) + +func InferRoot(queryNode *query.Select, fallback string) (string, string, error) { + name := SanitizeName(fallback) + if name == "" { + name = "DQLView" + } + if queryNode == nil { + return name, name, nil + } + if alias := SanitizeName(queryNode.From.Alias); alias != "" { + name = alias + } + table := "" + if queryNode.From.X != nil { + table = strings.TrimSpace(sqlparser.Stringify(queryNode.From.X)) + } + if name == "" || name == SanitizeName(fallback) { + if subAlias := inferSubqueryAlias(table); subAlias != "" { + name = subAlias + } + } + if table == "" || strings.HasPrefix(table, "(") { + if inferred := inferSubqueryTable(table); inferred != "" { + table = inferred + } else { + table = name + } + } + if name == "" { + return "", "", fmt.Errorf("shape compile: failed to infer view name") + } + return name, table, nil +} + +func inferSubqueryAlias(fromExpr string) string { + fromExpr = strings.TrimSpace(fromExpr) + if fromExpr == "" || !strings.HasPrefix(fromExpr, "(") { + return "" + } + depth := 0 + closeIdx := -1 + for i := 0; i < len(fromExpr); i++ { + switch fromExpr[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + closeIdx = i + i = len(fromExpr) + } + } + } + if closeIdx == -1 || closeIdx+1 >= len(fromExpr) { + return "" + } + rest := 
strings.TrimSpace(fromExpr[closeIdx+1:]) + restLower := strings.ToLower(rest) + if strings.HasPrefix(restLower, "as ") { + rest = strings.TrimSpace(rest[3:]) + } + if rest == "" { + return "" + } + end := 0 + for end < len(rest) { + c := rest[end] + if !(c == '_' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (end > 0 && c >= '0' && c <= '9')) { + break + } + end++ + } + if end == 0 { + return "" + } + return SanitizeName(rest[:end]) +} + +func inferSubqueryTable(fromExpr string) string { + inner, ok := extractSubqueryBody(fromExpr) + if !ok { + return "" + } + normalized := normalizeParserSQL(inner) + queryNode, _, err := ParseSelectWithDiagnostic(normalized) + if err != nil || queryNode == nil { + return "" + } + _, table, err := InferRoot(queryNode, "") + if err != nil { + return "" + } + table = strings.TrimSpace(strings.Trim(table, "`\"")) + if strings.EqualFold(table, "DQLView") { + return "" + } + return table +} + +func extractSubqueryBody(fromExpr string) (string, bool) { + fromExpr = strings.TrimSpace(fromExpr) + if !strings.HasPrefix(fromExpr, "(") { + return "", false + } + depth := 0 + for i := 0; i < len(fromExpr); i++ { + switch fromExpr[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + if i <= 1 { + return "", false + } + return strings.TrimSpace(fromExpr[1:i]), true + } + } + } + return "", false +} + +func InferProjectionType(queryNode *query.Select) (reflect.Type, reflect.Type, string) { + if queryNode == nil || len(queryNode.List) == 0 || queryNode.List.IsStarExpr() { + return reflect.TypeOf([]map[string]interface{}{}), reflect.TypeOf(map[string]interface{}{}), "many" + } + fields := make([]reflect.StructField, 0, len(queryNode.List)) + used := map[string]int{} + for index, item := range queryNode.List { + column := sqlparser.NewColumn(item) + columnName := strings.TrimSpace(column.Identity()) + if columnName == "" { + columnName = fmt.Sprintf("col_%d", index+1) + } + fieldName := ExportedName(columnName) + if 
fieldName == "" { + fieldName = fmt.Sprintf("Col%d", index+1) + } + if count := used[fieldName]; count > 0 { + fieldName = fmt.Sprintf("%s%d", fieldName, count+1) + } + used[fieldName]++ + + typ := parseColumnType(column.Type) + fields = append(fields, reflect.StructField{ + Name: fieldName, + Type: typ, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s,omitempty" sqlx:"name=%s"`, strings.ToLower(fieldName), columnName)), + }) + } + element := reflect.StructOf(fields) + return reflect.SliceOf(element), element, "many" +} + +func SanitizeName(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if value == strings.ToUpper(value) { + value = strings.ToLower(value) + } + value = replaceNonWordWithUnderscore(value) + value = strings.Trim(value, "_") + if value == "" { + return "" + } + if value[0] >= '0' && value[0] <= '9' { + value = "V_" + value + } + return value +} + +func ExportedName(value string) string { + value = replaceNonWordWithUnderscore(strings.TrimSpace(value)) + value = strings.Trim(value, "_") + if value == "" { + return "" + } + parts := strings.Split(strings.ToLower(value), "_") + for i, item := range parts { + if item == "" { + continue + } + parts[i] = strings.ToUpper(item[:1]) + item[1:] + } + name := strings.Join(parts, "") + if name == "" { + return "" + } + if name[0] >= '0' && name[0] <= '9' { + name = "N" + name + } + return name +} + +func replaceNonWordWithUnderscore(value string) string { + if value == "" { + return "" + } + var b strings.Builder + b.Grow(len(value)) + lastUnderscore := false + for i := 0; i < len(value); i++ { + ch := value[i] + isWord := ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') + if isWord { + b.WriteByte(ch) + lastUnderscore = false + continue + } + if !lastUnderscore { + b.WriteByte('_') + lastUnderscore = true + } + } + return b.String() +} + +func parseColumnType(dataType string) reflect.Type { + switch 
strings.ToLower(strings.TrimSpace(dataType)) { + case "", "string", "text", "varchar", "char", "uuid", "json", "jsonb": + return reflect.TypeOf("") + case "bool", "boolean": + return reflect.TypeOf(false) + case "int", "int32", "smallint", "integer": + return reflect.TypeOf(int(0)) + case "int64", "bigint": + return reflect.TypeOf(int64(0)) + case "float", "float32", "real": + return reflect.TypeOf(float32(0)) + case "float64", "double", "numeric", "decimal": + return reflect.TypeOf(float64(0)) + default: + return reflect.TypeOf("") + } +} diff --git a/repository/shape/compile/pipeline/infer_test.go b/repository/shape/compile/pipeline/infer_test.go new file mode 100644 index 000000000..748fcded4 --- /dev/null +++ b/repository/shape/compile/pipeline/infer_test.go @@ -0,0 +1,42 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/sqlparser" +) + +func TestInferSubqueryAlias(t *testing.T) { + assert.Equal(t, "session", inferSubqueryAlias("(SELECT * FROM session) session JOIN (SELECT * FROM attr) attribute ON attribute.id = session.id")) + assert.Equal(t, "x", inferSubqueryAlias("(SELECT 1) AS x")) + assert.Equal(t, "publisherglobaloverride", inferSubqueryAlias(`( + SELECT MIN(g.BUSINESS_MODEL_ID) AS BUSINESS_MODEL_ID + FROM CI_GLOBAL_PUBLISHER_OVERRIDE g +) publisherglobaloverride`)) + assert.Equal(t, "", inferSubqueryAlias("orders o")) +} + +func TestSanitizeName_AllCapsToLower(t *testing.T) { + assert.Equal(t, "value", SanitizeName("VALUE")) + assert.Equal(t, "status", SanitizeName("STATUS")) +} + +func TestInferSubqueryTable(t *testing.T) { + assert.Equal(t, "CI_ADVERTISER", inferSubqueryTable("(SELECT a.* FROM CI_ADVERTISER a) advertiser")) + assert.Equal(t, "", inferSubqueryTable("orders o")) +} + +func TestInferRoot_SubqueryFrom(t *testing.T) { + queryNode, err := sqlparser.ParseQuery(`SELECT advertiser.* FROM (SELECT a.* FROM CI_ADVERTISER a) advertiser`) + assert.NoError(t, err) + name, table, err := 
InferRoot(queryNode, "fallback") + assert.NoError(t, err) + assert.Equal(t, "advertiser", name) + assert.Equal(t, "CI_ADVERTISER", table) +} + +func TestInferTableFromSQL_ResolvesTopLevelFrom(t *testing.T) { + sqlText := `SELECT a.*, EXISTS(SELECT 1 FROM CI_ENTITY_WATCHLIST w WHERE w.ENTITY_ID = a.ID) AS watching FROM (SELECT x.* FROM CI_ADVERTISER x) a` + assert.Equal(t, "CI_ADVERTISER", InferTableFromSQL(sqlText)) +} diff --git a/repository/shape/compile/pipeline/parse.go b/repository/shape/compile/pipeline/parse.go new file mode 100644 index 000000000..c897454a4 --- /dev/null +++ b/repository/shape/compile/pipeline/parse.go @@ -0,0 +1,62 @@ +package pipeline + +import ( + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/parsly" + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/query" +) + +func ParseSelectWithDiagnostic(sqlText string) (*query.Select, *dqlshape.Diagnostic, error) { + sqlText = trimLeadingBlockComments(sqlText) + var diagnostic *dqlshape.Diagnostic + onError := func(err error, cur *parsly.Cursor, _ interface{}) error { + offset := 0 + if cur != nil { + offset = cur.Pos + } + if offset < 0 { + offset = 0 + } + diagnostic = &dqlshape.Diagnostic{ + Code: dqldiag.CodeParseSyntax, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "check SQL syntax near the reported location", + Span: pointSpan(sqlText, offset), + } + return err + } + result, err := sqlparser.ParseQuery(sqlText, sqlparser.WithErrorHandler(onError)) + if err != nil { + if diagnostic == nil { + diagnostic = &dqlshape.Diagnostic{ + Code: dqldiag.CodeParseSyntax, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "check SQL syntax near the reported location", + Span: pointSpan(sqlText, 0), + } + } + return nil, diagnostic, err + } + if result == nil { + return nil, nil, nil + } + return result, nil, 
nil +} + +func trimLeadingBlockComments(sqlText string) string { + remaining := strings.TrimLeft(sqlText, " \t\r\n") + for strings.HasPrefix(remaining, "/*") { + end := strings.Index(remaining, "*/") + if end == -1 { + return remaining + } + remaining = strings.TrimLeft(remaining[end+2:], " \t\r\n") + } + return remaining +} diff --git a/repository/shape/compile/pipeline/parse_test.go b/repository/shape/compile/pipeline/parse_test.go new file mode 100644 index 000000000..69292fc8a --- /dev/null +++ b/repository/shape/compile/pipeline/parse_test.go @@ -0,0 +1,35 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" +) + +func TestParseSelectWithDiagnostic_OK(t *testing.T) { + queryNode, diag, err := ParseSelectWithDiagnostic("SELECT id FROM orders o") + require.NoError(t, err) + require.Nil(t, diag) + require.NotNil(t, queryNode) + assert.Equal(t, "o", queryNode.From.Alias) +} + +func TestParseSelectWithDiagnostic_Syntax(t *testing.T) { + queryNode, diag, err := ParseSelectWithDiagnostic("SELECT id FROM orders WHERE (") + require.Error(t, err) + require.Nil(t, queryNode) + require.NotNil(t, diag) + assert.Equal(t, dqldiag.CodeParseSyntax, diag.Code) + assert.Equal(t, 1, diag.Span.Start.Line) + assert.Greater(t, diag.Span.Start.Char, 1) +} + +func TestParseSelectWithDiagnostic_LeadingBlockComment(t *testing.T) { + queryNode, diag, err := ParseSelectWithDiagnostic("/* {\"URI\":\"/x\"} */\nSELECT id FROM orders o") + require.NoError(t, err) + require.Nil(t, diag) + require.NotNil(t, queryNode) + assert.Equal(t, "o", queryNode.From.Alias) +} diff --git a/repository/shape/compile/pipeline/policy.go b/repository/shape/compile/pipeline/policy.go new file mode 100644 index 000000000..432bcd6fc --- /dev/null +++ b/repository/shape/compile/pipeline/policy.go @@ -0,0 +1,28 @@ +package pipeline + +import dqlstmt 
"github.com/viant/datly/repository/shape/dql/statement" + +type Decision struct { + HasRead bool + HasExec bool + HasUnknown bool +} + +func Classify(statements dqlstmt.Statements) Decision { + var ret Decision + for _, stmt := range statements { + if stmt == nil { + continue + } + if stmt.Kind == dqlstmt.KindExec || stmt.Kind == dqlstmt.KindService { + ret.HasExec = true + continue + } + if stmt.Kind == dqlstmt.KindRead { + ret.HasRead = true + continue + } + ret.HasUnknown = true + } + return ret +} diff --git a/repository/shape/compile/pipeline/policy_test.go b/repository/shape/compile/pipeline/policy_test.go new file mode 100644 index 000000000..2ff7d3077 --- /dev/null +++ b/repository/shape/compile/pipeline/policy_test.go @@ -0,0 +1,36 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func TestClassify_ReadOnly(t *testing.T) { + decision := Classify(dqlstmt.New("SELECT id FROM orders")) + assert.True(t, decision.HasRead) + assert.False(t, decision.HasExec) + assert.False(t, decision.HasUnknown) +} + +func TestClassify_ExecOnly(t *testing.T) { + decision := Classify(dqlstmt.New("UPDATE orders SET id = 1")) + assert.False(t, decision.HasRead) + assert.True(t, decision.HasExec) + assert.False(t, decision.HasUnknown) +} + +func TestClassify_Mixed(t *testing.T) { + decision := Classify(dqlstmt.New("SELECT id FROM orders\nUPDATE orders SET id = 1")) + assert.True(t, decision.HasRead) + assert.True(t, decision.HasExec) + assert.False(t, decision.HasUnknown) +} + +func TestClassify_UnknownTemplateOnly(t *testing.T) { + decision := Classify(dqlstmt.New("$Foo.Bar($x)")) + assert.False(t, decision.HasRead) + assert.False(t, decision.HasExec) + assert.True(t, decision.HasUnknown) +} diff --git a/repository/shape/compile/pipeline/read.go b/repository/shape/compile/pipeline/read.go new file mode 100644 index 000000000..c665d154b --- /dev/null +++ 
b/repository/shape/compile/pipeline/read.go @@ -0,0 +1,205 @@ +package pipeline + +// read.go — SELECT compilation: parses DQL into a plan.View using +// multi-strategy parse with template-signal fallback. +// SQL normalization and token utilities live in read_normalize.go. + +import ( + "reflect" + "strings" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser/query" +) + +// BuildRead compiles a SELECT DQL fragment into a plan.View. +// It applies multiple parse strategies and gracefully degrades to a +// loose (schema-less) view for template-driven SQL that cannot be fully parsed. +func BuildRead(sourceName, sqlText string) (*plan.View, []*dqlshape.Diagnostic, error) { + queryNode, parseDiag, parserSQL, err := resolveQueryNode(sqlText) + + // Template-driven SQL may legitimately fail strict parsing; treat as warning. + if (err != nil || parseDiag != nil) && hasTemplateSignals(sqlText) { + if parseDiag != nil { + parseDiag.Severity = dqlshape.SeverityWarning + } + return buildLooseRead(sourceName, sqlText), collectDiags(parseDiag), nil + } + + var diags []*dqlshape.Diagnostic + if parseDiag != nil { + diags = append(diags, parseDiag) + } + if err != nil { + return nil, diags, nil + } + + relations, relationDiags := ExtractJoinRelations(parserSQL, queryNode) + diags = append(diags, relationDiags...) 
+ + name, table, inferErr := InferRoot(queryNode, sourceName) + if inferErr != nil { + return nil, nil, inferErr + } + fallback := SanitizeName(sourceName) + if name == fallback && table == fallback { + if derived := inferRootFromRelations(relations); derived != "" { + name = derived + table = derived + } + } + + fieldType, elementType, cardinality := InferProjectionType(queryNode) + if fieldType == nil || elementType == nil { + fieldType = reflect.TypeOf([]map[string]interface{}{}) + elementType = reflect.TypeOf(map[string]interface{}{}) + cardinality = "many" + } + view := &plan.View{ + Path: name, + Holder: name, + Name: name, + Mode: "SQLQuery", + Table: table, + SQL: sqlText, + Cardinality: cardinality, + FieldType: fieldType, + ElementType: elementType, + Relations: relations, + } + return view, diags, nil +} + +// resolveQueryNode attempts to parse sqlText into a query AST using up to +// three strategies: +// 1. Parse the normalised form. +// 2. If normalisation broke the SQL, fall back to the raw form. +// 3. If the parsed result is structurally incomplete, retry with the +// normalised form to pick up joins the raw parse missed. +// +// It returns the best node, any diagnostic, the effective SQL used, and any +// parse error. +func resolveQueryNode(sqlText string) (node *query.Select, diag *dqlshape.Diagnostic, effectiveSQL string, err error) { + parserSQL := normalizeParserSQL(sqlText) + node, diag, err = ParseSelectWithDiagnostic(parserSQL) + + // Strategy 2: normalisation may have broken the SQL; try raw form. + if err != nil && parserSQL != sqlText { + if rawNode, _, rawErr := ParseSelectWithDiagnostic(sqlText); rawErr == nil && isUsableQuery(rawNode) { + return rawNode, nil, sqlText, nil + } + } + + // Strategy 3: parsed OK but result is incomplete (no FROM or missing JOINs); + // retry with the normalised form. 
+ if err == nil && needsFallbackParse(sqlText, node) { + fallbackSQL := normalizeParserSQL(sqlText) + if fallbackNode, _, fallbackErr := ParseSelectWithDiagnostic(fallbackSQL); fallbackErr == nil && isUsableQuery(fallbackNode) { + return fallbackNode, nil, fallbackSQL, nil + } + } + return node, diag, parserSQL, err +} + +func buildLooseRead(sourceName, sqlText string) *plan.View { + name, table := inferLooseRoot(sourceName, sqlText) + fieldType := reflect.TypeOf([]map[string]interface{}{}) + elementType := reflect.TypeOf(map[string]interface{}{}) + return &plan.View{ + Path: name, + Holder: name, + Name: name, + Mode: "SQLQuery", + Table: table, + SQL: sqlText, + Cardinality: "many", + FieldType: fieldType, + ElementType: elementType, + } +} + +func inferLooseRoot(sourceName, sqlText string) (string, string) { + name := SanitizeName(sourceName) + if name == "" { + name = "DQLView" + } + if table := extractSimpleFromTable(sqlText); table != "" { + return name, table + } + return name, name +} + +func hasTemplateSignals(sqlText string) bool { + lower := strings.ToLower(sqlText) + return strings.Contains(lower, "#if(") || strings.Contains(lower, "#elseif(") || strings.Contains(lower, "#else") || + strings.Contains(lower, "#end") || strings.Contains(lower, "${") || strings.Contains(lower, "$unsafe.") || + strings.Contains(lower, "$view.") || strings.Contains(lower, "$predicate.") +} + +func isUsableQuery(queryNode *query.Select) bool { + return queryNode != nil && queryNode.From.X != nil +} + +func needsFallbackParse(rawSQL string, queryNode *query.Select) bool { + if !isUsableQuery(queryNode) { + return true + } + lower := strings.ToLower(rawSQL) + if strings.Contains(lower, " join ") && len(queryNode.Joins) == 0 { + return true + } + return false +} + +func inferRootFromRelations(relations []*plan.Relation) string { + for _, relation := range relations { + if relation == nil { + continue + } + for _, link := range relation.On { + if link == nil { + continue + } + 
name := SanitizeName(link.ParentNamespace) + if name != "" { + return name + } + } + } + return "" +} + +func extractSimpleFromTable(sqlText string) string { + lower := strings.ToLower(sqlText) + for i := 0; i+4 <= len(lower); i++ { + if lower[i] != 'f' || !strings.HasPrefix(lower[i:], "from") { + continue + } + if i > 0 && isReadIdentifierPart(lower[i-1]) { + continue + } + j := skipReadSpaces(sqlText, i+4) + start := j + if start >= len(sqlText) || !isReadIdentifierStart(sqlText[start]) { + continue + } + j++ + for j < len(sqlText) && (isReadIdentifierPart(sqlText[j]) || sqlText[j] == '.' || sqlText[j] == '$') { + j++ + } + if start < j { + return strings.Trim(sqlText[start:j], "`\"") + } + } + return "" +} + +// collectDiags returns a single-element slice for a non-nil diagnostic, +// or nil otherwise. Used to avoid repeated nil checks at call sites. +func collectDiags(diag *dqlshape.Diagnostic) []*dqlshape.Diagnostic { + if diag == nil { + return nil + } + return []*dqlshape.Diagnostic{diag} +} diff --git a/repository/shape/compile/pipeline/read_normalize.go b/repository/shape/compile/pipeline/read_normalize.go new file mode 100644 index 000000000..6ff42af2f --- /dev/null +++ b/repository/shape/compile/pipeline/read_normalize.go @@ -0,0 +1,266 @@ +package pipeline + +// read_normalize.go — SQL normalization and template-token replacement used +// by BuildRead to produce parser-friendly SQL from raw DQL. + +import "strings" + +// normalizeParserSQL rewrites private(…) shorthands and template tokens into +// plain SQL that the parser can handle. 
+func normalizeParserSQL(sqlText string) string { + if sqlText == "" { + return sqlText + } + return rewritePrivateShorthand(replaceTemplateTokens(sqlText)) +} + +func rewritePrivateShorthand(input string) string { + var b strings.Builder + b.Grow(len(input)) + for i := 0; i < len(input); { + if !hasPrefixFold(input[i:], "private") { + b.WriteByte(input[i]) + i++ + continue + } + if i > 0 && isReadIdentifierPart(input[i-1]) { + b.WriteByte(input[i]) + i++ + continue + } + pos := i + len("private") + pos = skipReadSpaces(input, pos) + if pos >= len(input) || input[pos] != '(' { + b.WriteByte(input[i]) + i++ + continue + } + body, closeIdx, ok := readReadCallBody(input, pos) + if !ok { + b.WriteByte(input[i]) + i++ + continue + } + firstArg, ok := firstCallArg(body) + if !ok { + b.WriteByte(input[i]) + i++ + continue + } + b.WriteString(strings.TrimSpace(firstArg)) + i = closeIdx + 1 + } + return b.String() +} + +func hasPrefixFold(s, prefix string) bool { + if len(s) < len(prefix) { + return false + } + return strings.EqualFold(s[:len(prefix)], prefix) +} + +func firstCallArg(body string) (string, bool) { + depth := 0 + quote := byte(0) + for i := 0; i < len(body); i++ { + ch := body[i] + if quote != 0 { + if ch == '\\' && i+1 < len(body) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + switch ch { + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + case ',': + if depth == 0 { + arg := strings.TrimSpace(body[:i]) + return arg, arg != "" + } + } + } + arg := strings.TrimSpace(body) + return arg, arg != "" +} + +func replaceTemplateTokens(input string) string { + var b strings.Builder + b.Grow(len(input)) + for i := 0; i < len(input); { + if input[i] != '$' { + b.WriteByte(input[i]) + i++ + continue + } + if i+1 < len(input) && input[i+1] == '{' { + body, end, ok := readReadTemplateExpr(input, i+1) + if !ok { + b.WriteByte(input[i]) + i++ + continue + } + replacement, 
keep := normalizeTemplateExprBody(body) + if keep { + b.WriteString(input[i : end+1]) + } else { + b.WriteString(replacement) + } + i = end + 1 + continue + } + token, end, ok := readReadSelector(input, i) + if !ok { + b.WriteByte(input[i]) + i++ + continue + } + if strings.EqualFold(token, "$criteria.AppendBinding") { + pos := skipReadSpaces(input, end) + if pos < len(input) && input[pos] == '(' { + _, close, ok := readReadCallBody(input, pos) + if ok { + b.WriteByte('1') + i = close + 1 + continue + } + } + } + if isReadReservedToken(token) { + b.WriteString(token) + } else { + b.WriteByte('1') + } + i = end + } + return b.String() +} + +func normalizeTemplateExprBody(body string) (string, bool) { + trimmed := strings.TrimSpace(body) + if isReadReservedName(trimmed) { + return "", true + } + lower := strings.ToLower(trimmed) + if strings.Contains(lower, `build("where")`) || strings.Contains(lower, "build('where')") { + return " WHERE 1 ", false + } + if strings.Contains(lower, `build("and")`) || strings.Contains(lower, "build('and')") { + return " AND 1 ", false + } + return "1", false +} + +func readReadTemplateExpr(input string, openBrace int) (string, int, bool) { + if openBrace <= 0 || openBrace >= len(input) || input[openBrace] != '{' || input[openBrace-1] != '$' { + return "", -1, false + } + for i := openBrace + 1; i < len(input); i++ { + if input[i] == '}' { + return input[openBrace+1 : i], i, true + } + } + return "", -1, false +} + +func readReadSelector(input string, start int) (string, int, bool) { + if start < 0 || start >= len(input) || input[start] != '$' { + return "", start, false + } + i := start + 1 + if i >= len(input) || !isReadIdentifierStart(input[i]) { + return "", start, false + } + i++ + for i < len(input) && isReadIdentifierPart(input[i]) { + i++ + } + for i < len(input) && input[i] == '.' 
{ + i++ + if i >= len(input) || !isReadIdentifierStart(input[i]) { + return "", start, false + } + i++ + for i < len(input) && isReadIdentifierPart(input[i]) { + i++ + } + } + return input[start:i], i, true +} + +func readReadCallBody(input string, openParen int) (string, int, bool) { + depth := 0 + quote := byte(0) + for i := openParen; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + depth-- + if depth == 0 { + return input[openParen+1 : i], i, true + } + } + } + return "", -1, false +} + +func isReadReservedToken(token string) bool { + if len(token) > 0 && token[0] == '$' { + token = token[1:] + } + return isReadReservedName(token) +} + +func isReadReservedName(name string) bool { + return name == "sql.Insert" || name == "sql.Update" || name == "Nop" +} + +func skipReadSpaces(input string, index int) int { + for index < len(input) { + switch input[index] { + case ' ', '\t', '\n', '\r': + index++ + default: + return index + } + } + return index +} + +func isReadIdentifierStart(ch byte) bool { + return ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') +} + +func isReadIdentifierPart(ch byte) bool { + return isReadIdentifierStart(ch) || (ch >= '0' && ch <= '9') +} diff --git a/repository/shape/compile/pipeline/read_test.go b/repository/shape/compile/pipeline/read_test.go new file mode 100644 index 000000000..9d414beb8 --- /dev/null +++ b/repository/shape/compile/pipeline/read_test.go @@ -0,0 +1,73 @@ +package pipeline + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/query" +) + +func TestBuildRead(t *testing.T) { + view, diags, err := BuildRead("orders_report", "SELECT o.id, i.sku FROM 
orders o JOIN items i ON o.id = i.order_id") + require.NoError(t, err) + require.NotNil(t, view) + assert.Equal(t, "o", view.Name) + assert.Equal(t, "orders", view.Table) + assert.Equal(t, "many", view.Cardinality) + require.Len(t, view.Relations, 1) + assert.Equal(t, "i", view.Relations[0].Ref) + assert.Empty(t, diags) +} + +func TestBuildRead_SubqueryJoin_UsesParentNamespaceAsRoot(t *testing.T) { + sqlText := `SELECT session.* +FROM (SELECT * FROM session WHERE user_id = $criteria.AppendBinding($Unsafe.Jwt.UserID)) session +JOIN (SELECT * FROM session/attributes) attribute ON attribute.user_id = session.user_id` + view, _, err := BuildRead("system/session", sqlText) + require.NoError(t, err) + require.NotNil(t, view) + assert.Equal(t, "session", view.Name) + assert.Equal(t, "session", view.Table) + require.NotEmpty(t, view.Relations) + assert.Equal(t, "attribute", view.Relations[0].Ref) +} + +func TestNormalizeParserSQL(t *testing.T) { + input := "SELECT * FROM session WHERE user_id = $criteria.AppendBinding($Unsafe.Jwt.UserID) AND x = $Jwt.UserID" + actual := normalizeParserSQL(input) + assert.NotContains(t, actual, "$criteria.AppendBinding") + assert.NotContains(t, actual, "$Jwt.UserID") + assert.Contains(t, actual, "user_id = 1") +} + +func TestNormalizeParserSQL_VeltyBlockExpression(t *testing.T) { + input := `SELECT b.* FROM CI_BROWSER b ${predicate.Builder().CombineOr($predicate.FilterGroup(0, "AND")).Build("WHERE")} AND b.ARCHIVED = 0` + actual := normalizeParserSQL(input) + assert.NotContains(t, actual, "${predicate.Builder()") + assert.Contains(t, actual, "SELECT b.* FROM CI_BROWSER b WHERE 1 AND b.ARCHIVED = 0") +} + +func TestNormalizeParserSQL_PrivateShorthand(t *testing.T) { + input := `SELECT private(audience.FREQ_CAPPING) AS freq_capping FROM CI_AUDIENCE audience` + actual := normalizeParserSQL(input) + assert.NotContains(t, strings.ToLower(actual), "private(") + assert.Contains(t, actual, "SELECT audience.FREQ_CAPPING AS freq_capping FROM 
CI_AUDIENCE audience") +} + +func TestNeedsFallbackParse(t *testing.T) { + assert.True(t, needsFallbackParse("SELECT * FROM t JOIN x ON t.id = x.id", &query.Select{})) + assert.False(t, needsFallbackParse("SELECT * FROM t", &query.Select{From: query.From{X: expr.NewSelector("t")}})) +} + +func TestBuildRead_FallbackWhenInitialParseFails(t *testing.T) { + sqlText := `SELECT b.* FROM CI_BROWSER b ${predicate.Builder().CombineOr($predicate.FilterGroup(0, "AND")).Build("WHERE")} AND b.ARCHIVED = 0` + view, diags, err := BuildRead("browser", sqlText) + require.NoError(t, err) + require.NotNil(t, view) + assert.Equal(t, "b", view.Name) + assert.Equal(t, "CI_BROWSER", view.Table) + assert.Empty(t, diags) +} diff --git a/repository/shape/compile/pipeline/relation.go b/repository/shape/compile/pipeline/relation.go new file mode 100644 index 000000000..dc94b8955 --- /dev/null +++ b/repository/shape/compile/pipeline/relation.go @@ -0,0 +1,390 @@ +package pipeline + +import ( + "fmt" + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/node" + "github.com/viant/sqlparser/query" +) + +func ExtractJoinRelations(raw string, queryNode *query.Select) ([]*plan.Relation, []*dqlshape.Diagnostic) { + if queryNode == nil || len(queryNode.Joins) == 0 { + return nil, nil + } + rootAlias := rootNamespace(queryNode) + var relations []*plan.Relation + var diagnostics []*dqlshape.Diagnostic + + for idx, join := range queryNode.Joins { + if join == nil { + continue + } + offset := relationOffset(raw, join) + span := pointSpan(raw, offset) + ref, table := relationRef(join, idx+1) + relation := &plan.Relation{ + Name: ref, + Holder: ExportedName(ref), + Ref: ref, + Table: table, + Kind: strings.TrimSpace(join.Kind), + Raw: strings.TrimSpace(join.Raw), + } + if 
relation.Holder == "" { + relation.Holder = fmt.Sprintf("Rel%d", idx+1) + } + if join.On == nil || join.On.X == nil { + diagnostics = append(diagnostics, &dqlshape.Diagnostic{ + Code: dqldiag.CodeRelMissingON, + Severity: dqlshape.SeverityWarning, + Message: "join is missing ON condition", + Hint: "use explicit ON condition to derive relation links", + Span: span, + }) + relation.Warnings = append(relation.Warnings, "missing ON condition") + relations = append(relations, relation) + continue + } + pairs := collectJoinPairs(join.On.X) + if len(pairs) == 0 { + onExpr := strings.TrimSpace(sqlparser.Stringify(join.On.X)) + if shouldFallbackToRawJoinPairs(onExpr) { + pairs = collectJoinPairsFromRaw(onExpr) + } + } + if len(pairs) == 0 { + diagnostics = append(diagnostics, &dqlshape.Diagnostic{ + Code: dqldiag.CodeRelUnsupported, + Severity: dqlshape.SeverityWarning, + Message: "join ON condition could not be translated into relation links", + Hint: "use equality predicates between concrete columns, e.g. 
a.id = b.ref_id", + Span: span, + }) + relation.Warnings = append(relation.Warnings, "unsupported ON predicate") + relations = append(relations, relation) + continue + } + for _, pair := range pairs { + link, warning := orientJoinPair(pair, rootAlias, ref) + if warning != "" { + diagnostics = append(diagnostics, &dqlshape.Diagnostic{ + Code: dqldiag.CodeRelAmbiguous, + Severity: dqlshape.SeverityWarning, + Message: warning, + Hint: "use explicit aliases so one side belongs to root and the other to joined table", + Span: span, + }) + relation.Warnings = append(relation.Warnings, warning) + } + if link == nil { + continue + } + relation.On = append(relation.On, link) + } + if len(relation.On) == 0 { + diagnostics = append(diagnostics, &dqlshape.Diagnostic{ + Code: dqldiag.CodeRelNoLinks, + Severity: dqlshape.SeverityWarning, + Message: "join ON condition does not expose extractable column links", + Hint: "ensure both sides of '=' are concrete column references", + Span: span, + }) + relation.Warnings = append(relation.Warnings, "no extractable links") + } + relations = append(relations, relation) + } + return relations, diagnostics +} + +func collectJoinPairsFromRaw(input string) []joinPair { + input = strings.TrimSpace(input) + if input == "" { + return nil + } + var ( + result []joinPair + i int + ) + for i < len(input) { + left, next, ok := parseRelationSelector(input, i) + if !ok { + i++ + continue + } + j := skipRelationSpaces(input, next) + if j >= len(input) || input[j] != '=' { + i = next + continue + } + right, end, ok := parseRelationSelector(input, j+1) + if !ok { + i = j + 1 + continue + } + if strings.TrimSpace(left) == "" || strings.TrimSpace(right) == "" { + i = end + continue + } + result = append(result, joinPair{left: left, right: right}) + i = end + } + return result +} + +func parseRelationSelector(input string, start int) (string, int, bool) { + i := skipRelationSpaces(input, start) + nsStart := i + if nsStart >= len(input) || 
!isRelationIdentifierStart(input[nsStart]) { + return "", start, false + } + i++ + for i < len(input) && isRelationIdentifierPart(input[i]) { + i++ + } + ns := input[nsStart:i] + i = skipRelationSpaces(input, i) + if i >= len(input) || input[i] != '.' { + return "", start, false + } + i++ + i = skipRelationSpaces(input, i) + colStart := i + if colStart >= len(input) || !isRelationIdentifierStart(input[colStart]) { + return "", start, false + } + i++ + for i < len(input) && isRelationIdentifierPart(input[i]) { + i++ + } + col := input[colStart:i] + return strings.TrimSpace(ns) + "." + strings.TrimSpace(col), i, true +} + +func skipRelationSpaces(input string, index int) int { + for index < len(input) { + switch input[index] { + case ' ', '\t', '\n', '\r': + index++ + default: + return index + } + } + return index +} + +func isRelationIdentifierStart(ch byte) bool { + return ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') +} + +func isRelationIdentifierPart(ch byte) bool { + return isRelationIdentifierStart(ch) || (ch >= '0' && ch <= '9') +} + +func shouldFallbackToRawJoinPairs(input string) bool { + input = strings.TrimSpace(strings.ToLower(input)) + if input == "" { + return false + } + // Restrict raw fallback to simple selector equality text to avoid brittle extraction + // for quoted identifiers, function calls, casts, and richer predicates. 
+ bannedFragments := []string{ + "`", "\"", "'", "(", ")", "::", " collate ", " case ", " when ", " then ", " else ", " end ", + " coalesce", " cast", " concat", " substr", " lower", " upper", " trim", + } + for _, fragment := range bannedFragments { + if strings.Contains(input, fragment) { + return false + } + } + return true +} + +type joinPair struct { + left string + right string +} + +func collectJoinPairs(n node.Node) []joinPair { + switch actual := n.(type) { + case *expr.Binary: + op := strings.ToUpper(strings.TrimSpace(actual.Op)) + if op == "AND" || op == "OR" { + left := collectJoinPairs(actual.X) + right := collectJoinPairs(actual.Y) + return append(left, right...) + } + if op != "=" { + return nil + } + left := selectorName(actual.X) + right := selectorName(actual.Y) + if left == "" || right == "" { + return nil + } + return []joinPair{{left: left, right: right}} + case *expr.Parenthesis: + return collectJoinPairs(actual.X) + default: + return nil + } +} + +func selectorName(n node.Node) string { + switch actual := n.(type) { + case *expr.Selector: + return strings.TrimSpace(sqlparser.Stringify(actual)) + case *expr.Parenthesis: + return selectorName(actual.X) + default: + return "" + } +} + +func orientJoinPair(pair joinPair, rootAlias, refAlias string) (*plan.RelationLink, string) { + leftNS, leftCol := splitSelector(pair.left) + rightNS, rightCol := splitSelector(pair.right) + if leftCol == "" || rightCol == "" { + return nil, "" + } + switch { + case leftNS == rootAlias && (rightNS == refAlias || rightNS == ""): + return &plan.RelationLink{ + ParentNamespace: leftNS, + ParentColumn: leftCol, + RefNamespace: firstNonEmpty(rightNS, refAlias), + RefColumn: rightCol, + Expression: pair.left + "=" + pair.right, + }, "" + case rightNS == rootAlias && (leftNS == refAlias || leftNS == ""): + return &plan.RelationLink{ + ParentNamespace: rightNS, + ParentColumn: rightCol, + RefNamespace: firstNonEmpty(leftNS, refAlias), + RefColumn: leftCol, + Expression: 
pair.left + "=" + pair.right, + }, "" + case leftNS == "" && rightNS == "": + return &plan.RelationLink{ + ParentNamespace: rootAlias, + ParentColumn: leftCol, + RefNamespace: refAlias, + RefColumn: rightCol, + Expression: pair.left + "=" + pair.right, + }, "join columns lack namespaces, relation orientation was inferred" + case leftNS == refAlias: + parentNS := rightNS + if parentNS == "" { + parentNS = rootAlias + } + return &plan.RelationLink{ + ParentNamespace: parentNS, + ParentColumn: rightCol, + RefNamespace: leftNS, + RefColumn: leftCol, + Expression: pair.left + "=" + pair.right, + }, "" + case rightNS == refAlias: + parentNS := leftNS + if parentNS == "" { + parentNS = rootAlias + } + return &plan.RelationLink{ + ParentNamespace: parentNS, + ParentColumn: leftCol, + RefNamespace: rightNS, + RefColumn: rightCol, + Expression: pair.left + "=" + pair.right, + }, "" + default: + return nil, fmt.Sprintf("ambiguous join link %q cannot be oriented between root=%q and ref=%q", pair.left+"="+pair.right, rootAlias, refAlias) + } +} + +func relationOffset(raw string, join *query.Join) int { + if strings.TrimSpace(raw) == "" { + return 0 + } + if join != nil && join.On != nil && join.On.X != nil { + if onExpr := strings.TrimSpace(sqlparser.Stringify(join.On.X)); onExpr != "" { + if idx := strings.Index(strings.ToLower(raw), strings.ToLower(onExpr)); idx >= 0 { + return idx + } + } + } + if join != nil && strings.TrimSpace(join.Raw) != "" { + if idx := strings.Index(strings.ToLower(raw), strings.ToLower(strings.TrimSpace(join.Raw))); idx >= 0 { + return idx + } + } + return 0 +} + +func rootNamespace(queryNode *query.Select) string { + if queryNode == nil { + return "" + } + if alias := strings.TrimSpace(queryNode.From.Alias); alias != "" { + return alias + } + if queryNode.From.X == nil { + return "" + } + root := strings.TrimSpace(sqlparser.Stringify(queryNode.From.X)) + root = strings.Trim(root, "`\"") + if root == "" { + return "" + } + if idx := 
strings.LastIndex(root, "."); idx != -1 { + root = root[idx+1:] + } + return root +} + +func relationRef(join *query.Join, ordinal int) (string, string) { + if join == nil { + return fmt.Sprintf("join_%d", ordinal), "" + } + ref := strings.TrimSpace(join.Alias) + table := "" + if join.With != nil { + table = strings.TrimSpace(sqlparser.Stringify(join.With)) + } + if ref == "" { + ref = table + if idx := strings.LastIndex(ref, "."); idx != -1 { + ref = ref[idx+1:] + } + } + ref = SanitizeName(strings.Trim(ref, "`\"")) + if ref == "" { + ref = fmt.Sprintf("join_%d", ordinal) + } + return ref, table +} + +func splitSelector(selector string) (string, string) { + selector = strings.TrimSpace(selector) + if selector == "" { + return "", "" + } + selector = strings.Trim(selector, "`\"") + if idx := strings.Index(selector, "."); idx != -1 { + return strings.Trim(selector[:idx], "`\""), strings.Trim(selector[idx+1:], "`\"") + } + return "", selector +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} diff --git a/repository/shape/compile/pipeline/relation_test.go b/repository/shape/compile/pipeline/relation_test.go new file mode 100644 index 000000000..62f5c9d3a --- /dev/null +++ b/repository/shape/compile/pipeline/relation_test.go @@ -0,0 +1,89 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + "github.com/viant/sqlparser" +) + +func TestExtractJoinRelations(t *testing.T) { + sqlText := "SELECT o.id FROM orders o JOIN order_items i ON o.id = i.order_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 1) + assert.Equal(t, "i", relations[0].Ref) + require.Len(t, relations[0].On, 1) + 
assert.Equal(t, "id", relations[0].On[0].ParentColumn) + assert.Equal(t, "order_id", relations[0].On[0].RefColumn) + assert.Empty(t, diags) +} + +func TestExtractJoinRelations_UnsupportedPredicate(t *testing.T) { + sqlText := "SELECT o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + _, diags := ExtractJoinRelations(sqlText, queryNode) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeRelUnsupported, diags[0].Code) +} + +func TestExtractJoinRelations_WithAndLiteral(t *testing.T) { + sqlText := "SELECT t.id FROM taxonomy t LEFT JOIN provider p ON p.id = t.provider_id AND 1=1" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 1) + require.Len(t, relations[0].On, 1) + assert.Equal(t, "provider_id", relations[0].On[0].ParentColumn) + assert.Equal(t, "id", relations[0].On[0].RefColumn) + assert.Empty(t, diags) +} + +func TestExtractJoinRelations_NonRootParentChain(t *testing.T) { + sqlText := "SELECT sl.id FROM site_list sl JOIN site_list_match m ON m.site_list_id = sl.id JOIN ci_site s ON s.id = m.site_id JOIN ci_publisher p ON p.id = s.publisher_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 3) + + require.Len(t, relations[0].On, 1) + assert.Equal(t, "sl", relations[0].On[0].ParentNamespace) + assert.Equal(t, "id", relations[0].On[0].ParentColumn) + assert.Equal(t, "m", relations[0].On[0].RefNamespace) + assert.Equal(t, "site_list_id", relations[0].On[0].RefColumn) + + require.Len(t, relations[1].On, 1) + assert.Equal(t, "m", relations[1].On[0].ParentNamespace) + assert.Equal(t, "site_id", relations[1].On[0].ParentColumn) + assert.Equal(t, "s", relations[1].On[0].RefNamespace) + assert.Equal(t, "id", relations[1].On[0].RefColumn) + + 
require.Len(t, relations[2].On, 1) + assert.Equal(t, "s", relations[2].On[0].ParentNamespace) + assert.Equal(t, "publisher_id", relations[2].On[0].ParentColumn) + assert.Equal(t, "p", relations[2].On[0].RefNamespace) + assert.Equal(t, "id", relations[2].On[0].RefColumn) + assert.Empty(t, diags) +} + +func TestExtractJoinRelations_DoesNotFallbackForComplexRawPredicate(t *testing.T) { + sqlText := "SELECT o.id FROM orders o JOIN order_items i ON COALESCE(o.id, 0) = i.order_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 1) + assert.Empty(t, relations[0].On) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeRelUnsupported, diags[0].Code) +} + +func TestShouldFallbackToRawJoinPairs(t *testing.T) { + assert.True(t, shouldFallbackToRawJoinPairs("o.id = i.order_id")) + assert.False(t, shouldFallbackToRawJoinPairs("COALESCE(o.id, 0) = i.order_id")) + assert.False(t, shouldFallbackToRawJoinPairs("`o`.`id` = `i`.`order_id`")) + assert.False(t, shouldFallbackToRawJoinPairs(`"o"."id" = "i"."order_id"`)) +} diff --git a/repository/shape/compile/pipeline/table.go b/repository/shape/compile/pipeline/table.go new file mode 100644 index 000000000..5888aeaec --- /dev/null +++ b/repository/shape/compile/pipeline/table.go @@ -0,0 +1,21 @@ +package pipeline + +import "strings" + +// InferTableFromSQL infers root table from SQL text using parser-first strategy. 
+func InferTableFromSQL(sqlText string) string { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" { + return "" + } + normalized := normalizeParserSQL(sqlText) + queryNode, _, err := ParseSelectWithDiagnostic(normalized) + if err != nil || queryNode == nil { + return "" + } + _, table, err := InferRoot(queryNode, "") + if err != nil { + return "" + } + return strings.TrimSpace(strings.Trim(table, "`\"")) +} diff --git a/repository/shape/compile/policy.go b/repository/shape/compile/policy.go new file mode 100644 index 000000000..1bd02ca63 --- /dev/null +++ b/repository/shape/compile/policy.go @@ -0,0 +1,48 @@ +package compile + +import ( + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func hasEscalationWarnings(diags []*dqlshape.Diagnostic) bool { + for _, item := range diags { + if item == nil { + continue + } + if item.Severity != dqlshape.SeverityWarning { + continue + } + if strings.HasPrefix(item.Code, dqldiag.PrefixRel) || strings.HasPrefix(item.Code, dqldiag.PrefixSQLI) { + return true + } + } + return false +} + +func hasErrorDiagnostics(diags []*dqlshape.Diagnostic) bool { + for _, item := range diags { + if item == nil { + continue + } + if item.Severity == dqlshape.SeverityError { + return true + } + } + return false +} + +func filterEscalationDiagnostics(diags []*dqlshape.Diagnostic) []*dqlshape.Diagnostic { + var result []*dqlshape.Diagnostic + for _, item := range diags { + if item == nil { + continue + } + if strings.HasPrefix(item.Code, dqldiag.PrefixRel) || strings.HasPrefix(item.Code, dqldiag.PrefixSQLI) { + result = append(result, item) + } + } + return result +} diff --git a/repository/shape/compile/policy_test.go b/repository/shape/compile/policy_test.go new file mode 100644 index 000000000..a15e9f882 --- /dev/null +++ b/repository/shape/compile/policy_test.go @@ -0,0 +1,40 @@ +package compile + +import ( + "testing" + + 
"github.com/stretchr/testify/assert" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func TestPolicy_HasEscalationWarnings(t *testing.T) { + diags := []*dqlshape.Diagnostic{ + {Code: dqldiag.CodeRelAmbiguous, Severity: dqlshape.SeverityWarning}, + } + assert.True(t, hasEscalationWarnings(diags)) + assert.False(t, hasEscalationWarnings([]*dqlshape.Diagnostic{ + {Code: dqldiag.CodeViewMissingSQL, Severity: dqlshape.SeverityWarning}, + })) +} + +func TestPolicy_HasErrorDiagnostics(t *testing.T) { + assert.True(t, hasErrorDiagnostics([]*dqlshape.Diagnostic{ + {Code: dqldiag.CodeParseSyntax, Severity: dqlshape.SeverityError}, + })) + assert.False(t, hasErrorDiagnostics([]*dqlshape.Diagnostic{ + {Code: dqldiag.CodeRelAmbiguous, Severity: dqlshape.SeverityWarning}, + })) +} + +func TestPolicy_FilterEscalationDiagnostics(t *testing.T) { + diags := []*dqlshape.Diagnostic{ + {Code: dqldiag.CodeViewMissingSQL, Severity: dqlshape.SeverityWarning}, + {Code: dqldiag.CodeSQLIRawSelector, Severity: dqlshape.SeverityWarning}, + {Code: dqldiag.CodeRelNoLinks, Severity: dqlshape.SeverityWarning}, + } + filtered := filterEscalationDiagnostics(diags) + assert.Len(t, filtered, 2) + assert.Equal(t, dqldiag.CodeSQLIRawSelector, filtered[0].Code) + assert.Equal(t, dqldiag.CodeRelNoLinks, filtered[1].Code) +} diff --git a/repository/shape/compile/preprocess_handler.go b/repository/shape/compile/preprocess_handler.go new file mode 100644 index 000000000..4ec1e9401 --- /dev/null +++ b/repository/shape/compile/preprocess_handler.go @@ -0,0 +1,125 @@ +package compile + +import ( + "os" + "path/filepath" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +type handlerPreprocessResult struct { + Pre *dqlpre.Result 
+ Statements dqlstmt.Statements + Decision pipeline.Decision + EffectiveSource *shape.Source +} + +func buildHandlerIfNeeded(source *shape.Source, pre *dqlpre.Result, statements dqlstmt.Statements, decision pipeline.Decision, layout compilePathLayout) *handlerPreprocessResult { + ret := &handlerPreprocessResult{ + Pre: pre, + Statements: statements, + Decision: decision, + EffectiveSource: source, + } + if source == nil { + return ret + } + unknownOnly := decision.HasUnknown && !decision.HasRead && !decision.HasExec + if !unknownOnly && !isHandlerSignal(source) { + return ret + } + if buildGeneratedFallbackIfNeeded(ret, source, layout) { + return ret + } + return ret +} + +func buildGeneratedFallbackIfNeeded(ret *handlerPreprocessResult, source *shape.Source, layout compilePathLayout) bool { + if ret == nil || source == nil { + return false + } + _ = layout + generated := strings.TrimSpace(resolveGeneratedCompanionDQL(source)) + if generated == "" { + return false + } + candidate := dqlpre.Prepare(generated) + if strings.TrimSpace(candidate.SQL) == "" { + return false + } + candidateStatements := dqlstmt.New(candidate.SQL) + candidateDecision := pipeline.Classify(candidateStatements) + if !candidateDecision.HasRead && !candidateDecision.HasExec { + return false + } + ret.Pre = candidate + ret.Statements = candidateStatements + ret.Decision = candidateDecision + return true +} + +// buildHandlerFromContractIfNeeded is kept as a legacy no-op shim for tests +// and callers migrated to buildHandlerIfNeeded/buildGeneratedFallbackIfNeeded. 
+func buildHandlerFromContractIfNeeded(_ *handlerPreprocessResult, _ *shape.Source, _ compilePathLayout) bool { + return false +} + +func resolveGeneratedLegacySource(source *shape.Source) *shape.Source { + if source == nil || strings.TrimSpace(source.Path) == "" { + return nil + } + path := filepath.Clean(source.Path) + normalized := filepath.ToSlash(path) + genIdx := strings.Index(normalized, "/gen/") + if genIdx == -1 { + return nil + } + prefix := normalized[:genIdx] + suffix := strings.TrimPrefix(normalized[genIdx+len("/gen/"):], "/") + parts := strings.Split(suffix, "/") + if len(parts) < 2 { + return nil + } + fileName := parts[len(parts)-1] + stem := strings.TrimSuffix(fileName, filepath.Ext(fileName)) + candidates := []string{ + filepath.FromSlash(prefix + "/" + fileName), + filepath.FromSlash(prefix + "/" + stem + ".sql"), + filepath.FromSlash(prefix + "/" + stem + ".dql"), + } + for _, candidate := range candidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + clone := *source + clone.Path = candidate + clone.DQL = string(data) + return &clone + } + return nil +} + +func isHandlerSignal(source *shape.Source) bool { + if source == nil { + return false + } + settings := extractRuleSettings(source, nil) + if settings != nil { + if strings.TrimSpace(settings.Type) != "" { + return true + } + if method := strings.TrimSpace(strings.ToUpper(settings.Method)); method != "" && method != "GET" { + return true + } + if strings.Contains(strings.ToLower(strings.TrimSpace(settings.URI)), "/proxy") { + return true + } + } + raw := strings.ToLower(strings.TrimSpace(source.DQL)) + return strings.Contains(raw, "$nop(") || strings.Contains(raw, "$proxy(") +} diff --git a/repository/shape/compile/preprocess_handler_test.go b/repository/shape/compile/preprocess_handler_test.go new file mode 100644 index 000000000..1f8c8d08d --- /dev/null +++ b/repository/shape/compile/preprocess_handler_test.go @@ -0,0 +1,103 @@ +package compile + +import ( + 
"os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func TestIsHandlerSignal(t *testing.T) { + assert.True(t, isHandlerSignal(&shape.Source{DQL: `/* {"Type":"campaign/patch.Handler"} */`})) + assert.True(t, isHandlerSignal(&shape.Source{DQL: `$Nop($Data)`})) + assert.True(t, isHandlerSignal(&shape.Source{DQL: `$Proxy($Data)`})) + assert.False(t, isHandlerSignal(&shape.Source{DQL: `SELECT id FROM proxy_audit`})) + assert.False(t, isHandlerSignal(&shape.Source{DQL: `/* proxy disabled */ SELECT 1`})) + assert.False(t, isHandlerSignal(&shape.Source{DQL: `SELECT 1`})) +} + +func TestBuildHandlerFromContractIfNeeded_Disabled(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + dql := `/* {"Type":"campaign/patch.Handler","Connector":"ci_ads"} */` + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + + source := &shape.Source{Path: sourcePath, DQL: dql} + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildHandlerFromContractIfNeeded(result, source, defaultCompilePathLayout()) + require.False(t, applied) + require.NotNil(t, result) +} + +func TestBuildGeneratedFallbackIfNeeded_GeneratedCompanion(t *testing.T) { + tempDir := t.TempDir() + dqlPath := filepath.Join(tempDir, "platform", "adorder", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder"), 0o755)) + require.NoError(t, 
os.WriteFile(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder", "patch.dql"), []byte("SELECT o.id FROM ORDERS o"), 0o644)) + source := &shape.Source{ + Name: "patch", + Path: dqlPath, + DQL: `/* {"Type":"adorder/patch.Handler"} */`, + } + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) + require.True(t, applied) + require.NotNil(t, result) + assert.Contains(t, result.Pre.SQL, "SELECT o.id FROM ORDERS o") + assert.True(t, result.Decision.HasRead) +} + +func TestResolveGeneratedLegacySource(t *testing.T) { + tempDir := t.TempDir() + genPath := filepath.Join(tempDir, "dql", "system", "session", "gen", "session", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) + require.NoError(t, os.WriteFile(genPath, []byte(`/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`), 0o644)) + legacySQL := filepath.Join(tempDir, "dql", "system", "session", "patch.sql") + require.NoError(t, os.MkdirAll(filepath.Dir(legacySQL), 0o755)) + require.NoError(t, os.WriteFile(legacySQL, []byte(`/* {"Type":"session/patch.Handler"} */`), 0o644)) + + source := &shape.Source{Path: genPath, DQL: `/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`} + actual := resolveGeneratedLegacySource(source) + require.NotNil(t, actual) + assert.Equal(t, legacySQL, actual.Path) + assert.Contains(t, actual.DQL, `"Type":"session/patch.Handler"`) +} + +func TestBuildGeneratedFallbackIfNeeded_NoGeneratedCompanionWithoutTypeHeader(t *testing.T) { + tempDir := t.TempDir() + genPath := filepath.Join(tempDir, "dql", "system", "session", "gen", "session", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) + require.NoError(t, os.WriteFile(genPath, []byte(`/* 
{"Method":"PATCH","URI":"/v1/api/system/session"} */`), 0o644)) + + source := &shape.Source{Path: genPath, DQL: `/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`} + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) + require.False(t, applied) + require.NotNil(t, result) +} + +func TestBuildGeneratedFallbackIfNeeded_NoGeneratedCompanion(t *testing.T) { + source := &shape.Source{Path: filepath.Join(t.TempDir(), "dql", "x", "y", "z.dql"), DQL: `SELECT 1`} + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) + assert.False(t, applied) +} diff --git a/repository/shape/compile/resolver.go b/repository/shape/compile/resolver.go new file mode 100644 index 000000000..6abf0b50f --- /dev/null +++ b/repository/shape/compile/resolver.go @@ -0,0 +1,96 @@ +package compile + +import ( + "context" + "net/http" + "strings" + + "github.com/viant/datly/repository/contract/signature" + "github.com/viant/datly/repository/shape/plan" +) + +// ComponentContract represents resolved component contract metadata. +type ComponentContract struct { + RouteKey string + Method string + URI string + OutputType string + Types []*plan.Type +} + +// ComponentResolver resolves component contract metadata for a route key. +type ComponentResolver interface { + ResolveContract(ctx context.Context, routeKey string) (*ComponentContract, error) +} + +// SignatureResolver adapts repository/contract/signature service +// to compile-time component contract resolution. 
+type SignatureResolver struct { + service *signature.Service +} + +// NewSignatureResolver creates signature-backed component resolver. +func NewSignatureResolver(ctx context.Context, apiPrefix, routesURL string) (*SignatureResolver, error) { + srv, err := signature.New(ctx, apiPrefix, routesURL) + if err != nil { + return nil, err + } + return &SignatureResolver{service: srv}, nil +} + +// ResolveContract resolves component contract by route key. +func (s *SignatureResolver) ResolveContract(_ context.Context, routeKey string) (*ComponentContract, error) { + method, uri := splitRouteKey(routeKey) + sig, err := s.service.Signature(method, uri) + if err != nil { + return nil, err + } + ret := &ComponentContract{ + RouteKey: normalizeRouteKey(method, uri), + Method: method, + URI: normalizeURI(uri), + } + if sig.Output != nil { + if dataType := strings.TrimSpace(sig.Output.DataType); dataType != "" { + ret.OutputType = dataType + } else if name := strings.TrimSpace(sig.Output.Name); name != "" { + name = strings.Trim(name, "*") + if name != "" { + ret.OutputType = "*" + name + } + } + } + for _, item := range sig.Types { + if item == nil { + continue + } + ret.Types = append(ret.Types, &plan.Type{ + Name: strings.TrimSpace(item.Name), + Alias: strings.TrimSpace(item.Alias), + DataType: strings.TrimSpace(item.DataType), + Cardinality: strings.TrimSpace(string(item.Cardinality)), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + }) + } + return ret, nil +} + +func splitRouteKey(routeKey string) (string, string) { + routeKey = strings.TrimSpace(routeKey) + if routeKey == "" { + return http.MethodGet, "/" + } + if idx := strings.Index(routeKey, ":"); idx != -1 { + method := strings.ToUpper(strings.TrimSpace(routeKey[:idx])) + uri := strings.TrimSpace(routeKey[idx+1:]) + if method == "" { + method = http.MethodGet + } + if uri == "" { + uri = "/" + } + return method, uri + } + return http.MethodGet, routeKey +} diff --git 
a/repository/shape/compile/route_index.go b/repository/shape/compile/route_index.go new file mode 100644 index 000000000..30691e27a --- /dev/null +++ b/repository/shape/compile/route_index.go @@ -0,0 +1,231 @@ +package compile + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/viant/datly/repository/shape" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" +) + +// RouteIndexEntry maps one source DQL file to one concrete method+URI route key. +type RouteIndexEntry struct { + RouteKey string + Method string + URI string + SourcePath string + Namespace string +} + +// RouteIndex stores source-to-route mapping and lookup structures. +type RouteIndex struct { + ByRouteKey map[string]*RouteIndexEntry + ByNamespace map[string][]*RouteIndexEntry + Conflicts map[string][]string +} + +// BuildRouteIndex scans DQL files and builds route-key mapping. +func BuildRouteIndex(paths []string, opts ...shape.CompileOption) (*RouteIndex, error) { + compileOptions := applyCompileOptions(opts) + layout := newCompilePathLayout(compileOptions) + index := &RouteIndex{ + ByRouteKey: map[string]*RouteIndexEntry{}, + ByNamespace: map[string][]*RouteIndexEntry{}, + Conflicts: map[string][]string{}, + } + if len(paths) == 0 { + return index, nil + } + normalized := make([]string, 0, len(paths)) + for _, item := range paths { + item = strings.TrimSpace(item) + if item == "" { + continue + } + normalized = append(normalized, item) + } + sort.Strings(normalized) + + for _, sourcePath := range normalized { + data, err := os.ReadFile(sourcePath) + if err != nil { + return nil, fmt.Errorf("route index: unable to read %s: %w", sourcePath, err) + } + sourceName := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + dql := string(data) + _, _, directives, _ := dqlpre.Extract(dql) + source := &shape.Source{ + Name: sourceName, + Path: sourcePath, + DQL: dql, + } + settings := extractRuleSettings(source, directives) + 
namespace, _ := dqlToRouteNamespaceWithLayout(sourcePath, layout) + uri := strings.TrimSpace(settings.URI) + if uri == "" { + uri = inferDefaultURI(namespace) + } + if uri == "" { + continue + } + methods := parseRouteMethods(settings.Method) + for _, method := range methods { + entry := &RouteIndexEntry{ + Method: method, + URI: normalizeURI(uri), + SourcePath: sourcePath, + Namespace: namespace, + } + entry.RouteKey = normalizeRouteKey(entry.Method, entry.URI) + index.addEntry(entry) + } + } + return index, nil +} + +// Resolve maps a component reference from current source context to route key. +// It returns false when route cannot be resolved deterministically. +func (r *RouteIndex) Resolve(ref, currentSource string, opts ...shape.CompileOption) (string, bool) { + if r == nil { + return "", false + } + method, value := splitRouteKey(ref) + value = strings.TrimSpace(value) + if value == "" { + return "", false + } + layout := newCompilePathLayout(applyCompileOptions(opts)) + routeKeyFromURI := func(uri string) (string, bool) { + key := normalizeRouteKey(method, uri) + if _, conflicted := r.Conflicts[key]; conflicted { + return "", false + } + if _, ok := r.ByRouteKey[key]; !ok { + return "", false + } + return key, true + } + + if strings.HasPrefix(value, "/v1/api/") || strings.HasPrefix(value, "v1/api/") || strings.HasPrefix(value, "/") { + return routeKeyFromURI(value) + } + + if strings.TrimSpace(currentSource) == "" { + return "", false + } + _, _, dqlRoot, ok := sourceRootsWithLayout(currentSource, layout) + if !ok { + return "", false + } + sourceNamespace, _ := dqlToRouteNamespaceWithLayout(currentSource, layout) + namespace := resolveComponentNamespaceWithNamespace(value, currentSource, dqlRoot, sourceNamespace) + if namespace == "" { + return "", false + } + entries := r.ByNamespace[strings.ToLower(strings.TrimSpace(namespace))] + if len(entries) == 0 { + return "", false + } + if len(entries) == 1 { + key := entries[0].RouteKey + if _, conflicted := 
r.Conflicts[key]; conflicted { + return "", false + } + return key, true + } + // Multiple methods under one namespace: require exact method match. + for _, candidate := range entries { + if candidate == nil { + continue + } + if strings.EqualFold(candidate.Method, method) { + if _, conflicted := r.Conflicts[candidate.RouteKey]; conflicted { + return "", false + } + return candidate.RouteKey, true + } + } + return "", false +} + +func (r *RouteIndex) addEntry(entry *RouteIndexEntry) { + if r == nil || entry == nil { + return + } + key := entry.RouteKey + if prev, exists := r.ByRouteKey[key]; exists && prev != nil && prev.SourcePath != entry.SourcePath { + if _, ok := r.Conflicts[key]; !ok { + r.Conflicts[key] = []string{prev.SourcePath} + } + r.Conflicts[key] = append(r.Conflicts[key], entry.SourcePath) + return + } + r.ByRouteKey[key] = entry + nsKey := strings.ToLower(strings.TrimSpace(entry.Namespace)) + if nsKey != "" { + r.ByNamespace[nsKey] = append(r.ByNamespace[nsKey], entry) + } +} + +func parseRouteMethods(input string) []string { + input = strings.TrimSpace(input) + if input == "" { + return []string{http.MethodGet} + } + parts := strings.Split(input, ",") + ret := make([]string, 0, len(parts)) + seen := map[string]bool{} + for _, part := range parts { + method := strings.ToUpper(strings.TrimSpace(part)) + if method == "" { + continue + } + if seen[method] { + continue + } + seen[method] = true + ret = append(ret, method) + } + if len(ret) == 0 { + return []string{http.MethodGet} + } + return ret +} + +func normalizeRouteKey(method, uri string) string { + method = strings.ToUpper(strings.TrimSpace(method)) + if method == "" { + method = http.MethodGet + } + return method + ":" + normalizeURI(uri) +} + +func normalizeURI(uri string) string { + uri = strings.TrimSpace(uri) + if uri == "" { + return "/" + } + if strings.HasPrefix(uri, "v1/api/") { + uri = "/" + uri + } + if !strings.HasPrefix(uri, "/") { + uri = "/" + uri + } + return uri +} + +func 
inferDefaultURI(namespace string) string { + namespace = strings.Trim(strings.TrimSpace(namespace), "/") + if namespace == "" { + return "" + } + parts := strings.Split(namespace, "/") + if len(parts) >= 2 && parts[len(parts)-1] == parts[len(parts)-2] { + parts = parts[:len(parts)-1] + } + return "/v1/api/" + strings.Join(parts, "/") +} diff --git a/repository/shape/compile/span.go b/repository/shape/compile/span.go new file mode 100644 index 000000000..154ff9b2f --- /dev/null +++ b/repository/shape/compile/span.go @@ -0,0 +1,10 @@ +package compile + +import ( + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func relationSpan(raw string, offset int) dqlshape.Span { + return dqlpre.PointSpan(raw, offset) +} diff --git a/repository/shape/compile/statedecl.go b/repository/shape/compile/statedecl.go new file mode 100644 index 000000000..bd401b11d --- /dev/null +++ b/repository/shape/compile/statedecl.go @@ -0,0 +1,307 @@ +package compile + +import ( + "strconv" + "strings" + + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view/extension" + st "github.com/viant/datly/view/state" + "github.com/viant/parsly" +) + +func appendDeclaredStates(rawDQL string, result *plan.Result) { + if result == nil || strings.TrimSpace(rawDQL) == "" { + return + } + seen := map[string]bool{} + for _, block := range extractSetBlocks(rawDQL) { + holder, kind, location, tail, ok := parseSetDeclarationBody(block.Body) + if !ok { + continue + } + if kind == "view" || kind == "data_view" { + continue + } + key := declaredStateKey(holder, kind, location) + if seen[key] { + continue + } + state := &plan.State{ + Parameter: st.Parameter{ + Name: holder, + In: &st.Location{ + Kind: st.Kind(kind), + Name: location, + }, + }, + } + if inType, outType := parseSetDeclarationTypes(block.Body); inType != "" || outType != "" { + ensureStateSchema(state).DataType = inType + state.OutputDataType = 
outType + } + switch st.Kind(strings.ToLower(kind)) { + case st.KindQuery: + required := false + state.Required = &required + case st.KindHeader: + required := true + state.Required = &required + } + applyDeclaredStateOptions(state, tail) + result.States = append(result.States, state) + seen[key] = true + } +} + +func declaredStateKey(name, kind, in string) string { + return strings.ToLower(strings.TrimSpace(name)) + "|" + + strings.ToLower(strings.TrimSpace(kind)) + "|" + + strings.ToLower(strings.TrimSpace(in)) +} + +func applyDeclaredStateOptions(state *plan.State, tail string) { + if state == nil || strings.TrimSpace(tail) == "" { + return + } + cursor := newOptionCursor(tail) + for cursor.next() { + name, args := cursor.option() + switch { + case strings.EqualFold(name, "WithURI"): + if len(args) == 1 { + state.URI = trimQuote(args[0]) + } + case strings.EqualFold(name, "Optional"): + required := false + state.Required = &required + case strings.EqualFold(name, "Required"): + required := true + state.Required = &required + case strings.EqualFold(name, "Cacheable"): + if len(args) == 1 { + if value, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[0]))); err == nil { + state.Cacheable = &value + } + } + case strings.EqualFold(name, "QuerySelector"): + if len(args) == 1 { + state.QuerySelector = trimQuote(args[0]) + if state.Cacheable == nil { + cacheable := false + state.Cacheable = &cacheable + } + } + case strings.EqualFold(name, "WithPredicate"), strings.EqualFold(name, "Predicate"): + appendStatePredicate(state, args, false) + case strings.EqualFold(name, "EnsurePredicate"): + appendStatePredicate(state, args, true) + case strings.EqualFold(name, "When"): + if len(args) == 1 { + state.When = trimQuote(args[0]) + } + case strings.EqualFold(name, "Scope"): + if len(args) == 1 { + state.Scope = trimQuote(args[0]) + } + case strings.EqualFold(name, "WithType"): + if len(args) == 1 { + ensureStateSchema(state).DataType = trimQuote(args[0]) + } + case 
strings.EqualFold(name, "WithCodec"): + if len(args) >= 1 { + state.Output = &st.Codec{ + Name: trimQuote(args[0]), + Args: append([]string{}, trimQuotedArgs(args[1:])...), + } + } + case strings.EqualFold(name, "WithStatusCode"): + if len(args) == 1 { + if value, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))); err == nil { + state.ErrorStatusCode = value + } + } + case strings.EqualFold(name, "WithErrorMessage"): + if len(args) == 1 { + state.ErrorMessage = trimQuote(args[0]) + } + case strings.EqualFold(name, "Value"): + if len(args) == 1 { + state.Value = trimQuote(args[0]) + } + case strings.EqualFold(name, "Async"): + state.Async = true + } + } +} + +func parseSetDeclarationTypes(body string) (string, string) { + cursor := parsly.NewCursor("", []byte(body), 0) + if cursor.MatchAfterOptional(vdWhitespaceMatcher, vdParamDeclMatcher).Code != vdParamDeclToken { + return "", "" + } + if _, matched := readIdentifier(cursor); !matched { + return "", "" + } + _ = cursor.MatchOne(vdWhitespaceMatcher) + matchedType := cursor.MatchOne(vdTypeMatcher) + if matchedType.Code != vdTypeToken { + return "", "" + } + typeExpr := strings.TrimSpace(matchedType.Text(cursor)) + if len(typeExpr) < 2 { + return "", "" + } + typeExpr = strings.TrimSpace(typeExpr[1 : len(typeExpr)-1]) + if typeExpr == "" { + return "", "" + } + args := splitArgs(typeExpr) + if len(args) == 0 { + return "", "" + } + inputType := strings.TrimSpace(trimQuote(args[0])) + outputType := "" + if len(args) > 1 { + outputType = strings.TrimSpace(trimQuote(args[1])) + } + if inputType == "?" { + inputType = "" + } + if outputType == "?" 
{ + outputType = "" + } + return inputType, outputType +} + +func trimQuotedArgs(input []string) []string { + if len(input) == 0 { + return nil + } + result := make([]string, 0, len(input)) + for _, item := range input { + result = append(result, trimQuote(item)) + } + return result +} + +func appendStatePredicate(state *plan.State, args []string, ensure bool) { + if state == nil || len(args) == 0 { + return + } + group := 0 + nameIdx := 0 + if len(args) >= 2 { + if parsed, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))); err == nil { + group = parsed + nameIdx = 1 + } + } + if len(args) <= nameIdx { + return + } + predicate := &extension.PredicateConfig{ + Group: group, + Name: trimQuote(args[nameIdx]), + Ensure: ensure, + Args: []string{}, + } + for _, arg := range args[nameIdx+1:] { + predicate.Args = append(predicate.Args, trimQuote(arg)) + } + state.Predicates = append(state.Predicates, predicate) +} + +func ensureStateSchema(state *plan.State) *st.Schema { + if state.Schema == nil { + state.Schema = &st.Schema{} + } + return state.Schema +} + +type optionCursor struct { + raw string + cursor int + name string + args []string +} + +func newOptionCursor(raw string) *optionCursor { + return &optionCursor{raw: raw} +} + +func (o *optionCursor) next() bool { + o.name = "" + o.args = nil + for o.cursor < len(o.raw) && (o.raw[o.cursor] == ' ' || o.raw[o.cursor] == '\n' || o.raw[o.cursor] == '\t' || o.raw[o.cursor] == '\r') { + o.cursor++ + } + if o.cursor >= len(o.raw) || o.raw[o.cursor] != '.' 
{ + return false + } + o.cursor++ + start := o.cursor + for o.cursor < len(o.raw) { + ch := o.raw[o.cursor] + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' { + o.cursor++ + continue + } + break + } + if o.cursor == start { + return false + } + o.name = strings.TrimSpace(o.raw[start:o.cursor]) + for o.cursor < len(o.raw) && (o.raw[o.cursor] == ' ' || o.raw[o.cursor] == '\n' || o.raw[o.cursor] == '\t' || o.raw[o.cursor] == '\r') { + o.cursor++ + } + if o.cursor >= len(o.raw) || o.raw[o.cursor] != '(' { + return false + } + groupStart := o.cursor + depth := 0 + inSingle := false + inDouble := false + escape := false + for o.cursor < len(o.raw) { + ch := o.raw[o.cursor] + if escape { + escape = false + o.cursor++ + continue + } + switch ch { + case '\\': + escape = true + case '\'': + if !inDouble { + inSingle = !inSingle + } + case '"': + if !inSingle { + inDouble = !inDouble + } + case '(': + if !inSingle && !inDouble { + depth++ + } + case ')': + if !inSingle && !inDouble { + depth-- + if depth == 0 { + o.cursor++ + content := o.raw[groupStart+1 : o.cursor-1] + o.args = splitArgs(content) + return true + } + } + } + o.cursor++ + } + return false +} + +func (o *optionCursor) option() (string, []string) { + return o.name, o.args +} diff --git a/repository/shape/compile/statedecl_test.go b/repository/shape/compile/statedecl_test.go new file mode 100644 index 000000000..33c9241a4 --- /dev/null +++ b/repository/shape/compile/statedecl_test.go @@ -0,0 +1,79 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/plan" +) + +func TestAppendDeclaredStates(t *testing.T) { + dql := ` +#set($_ = $Jwt(header/Authorization).WithCodec(JwtClaim).WithStatusCode(401)) +#set($_ = $Claims(header/Authorization).WithCodec(JwtClaim)) +#set($_ = $Name(query/name).WithPredicate(0,'contains','sl','NAME').Optional()) +#set($_ = 
$Fields<[]string>(query/fields).QuerySelector(site_list)) +#set($_ = $Meta(output/summary)) +SELECT id FROM SITE_LIST sl` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.NotEmpty(t, result.States) + + byName := map[string]*plan.State{} + for _, item := range result.States { + if item != nil { + byName[item.Name] = item + } + } + require.NotNil(t, byName["Jwt"]) + assert.Equal(t, "header", byName["Jwt"].KindString()) + assert.Equal(t, "string", byName["Jwt"].Schema.DataType) + assert.Equal(t, "JwtClaim", byName["Jwt"].Output.Name) + assert.Equal(t, 401, byName["Jwt"].ErrorStatusCode) + require.NotNil(t, byName["Jwt"].Required) + assert.True(t, *byName["Jwt"].Required) + + require.NotNil(t, byName["Claims"]) + assert.Equal(t, "string", byName["Claims"].Schema.DataType) + assert.Equal(t, "*JwtClaims", byName["Claims"].OutputDataType) + + require.NotNil(t, byName["Name"]) + assert.Equal(t, "query", byName["Name"].KindString()) + require.NotNil(t, byName["Name"].Required) + assert.False(t, *byName["Name"].Required) + require.Len(t, byName["Name"].Predicates, 1) + assert.Equal(t, "contains", byName["Name"].Predicates[0].Name) + assert.Equal(t, 0, byName["Name"].Predicates[0].Group) + + require.NotNil(t, byName["Fields"]) + assert.Equal(t, "site_list", byName["Fields"].QuerySelector) + require.NotNil(t, byName["Fields"].Cacheable) + assert.False(t, *byName["Fields"].Cacheable) +} + +func TestAppendDeclaredStates_DuplicateDeclaration_FirstWins(t *testing.T) { + dql := ` +#set($_ = $Active(query/active).WithPredicate(0,'equal','tas','IS_TARGETABLE').Optional()) +#set($_ = $Active(query/active).WithPredicate(0,'equal','tas','ACTIVE').Optional()) +SELECT id FROM CI_TV_AFFILIATE_STATION tas` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.Len(t, result.States, 1) + require.Len(t, result.States[0].Predicates, 1) + assert.Equal(t, "Active", result.States[0].Name) + assert.Equal(t, "IS_TARGETABLE", 
result.States[0].Predicates[0].Args[1]) +} + +func TestAppendDeclaredStates_SupportsDefineDirective(t *testing.T) { + dql := ` +#define($_ = $Auth(header/Authorization).Required()) +SELECT id FROM USERS u` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.Len(t, result.States, 1) + assert.Equal(t, "Auth", result.States[0].Name) + assert.Equal(t, "header", result.States[0].KindString()) + require.NotNil(t, result.States[0].Required) + assert.True(t, *result.States[0].Required) +} diff --git a/repository/shape/compile/type_support.go b/repository/shape/compile/type_support.go new file mode 100644 index 000000000..44a8b9bca --- /dev/null +++ b/repository/shape/compile/type_support.go @@ -0,0 +1,238 @@ +package compile + +import ( + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/x" +) + +func applyLinkedTypeSupport(result *plan.Result, source *shape.Source) { + if result == nil || source == nil { + return + } + registry := source.EnsureTypeRegistry() + if registry == nil || len(registry.Keys()) == 0 { + return + } + resolver := typectx.NewResolver(registry, result.TypeContext) + rootTypeKey := resolveRootTypeKey(source, resolver, registry) + existing := existingTypesByName(result.Types) + + for idx, item := range result.Views { + if item == nil { + continue + } + resolvedKey := resolveViewTypeKey(item, idx == 0, rootTypeKey, resolver, registry) + if resolvedKey == "" { + continue + } + resolvedType := registry.Lookup(resolvedKey) + if resolvedType == nil || resolvedType.Type == nil { + continue + } + rType := unwrapResolvedType(resolvedType.Type) + if rType == nil { + continue + } + typeExpr, typePkg := schemaTypeExpression(rType, result.TypeContext) + if shouldSetSchemaType(item) && typeExpr != "" { + item.SchemaType = typeExpr + } + name := strings.TrimSpace(rType.Name()) + if name == "" { + continue + } 
+ key := strings.ToLower(name) + if existing[key] { + continue + } + result.Types = append(result.Types, &plan.Type{ + Name: name, + DataType: typeExpr, + Cardinality: strings.TrimSpace(item.Cardinality), + Package: typePkg, + ModulePath: strings.TrimSpace(rType.PkgPath()), + }) + existing[key] = true + } +} + +func resolveRootTypeKey(source *shape.Source, resolver *typectx.Resolver, registry *x.Registry) string { + if source == nil || registry == nil { + return "" + } + if key := resolveTypeKey(strings.TrimSpace(source.TypeName), resolver, registry); key != "" { + return key + } + rType, err := source.ResolveRootType() + if err != nil || rType == nil { + return "" + } + return resolveTypeKey(x.NewType(rType).Key(), resolver, registry) +} + +func resolveViewTypeKey(item *plan.View, root bool, rootTypeKey string, resolver *typectx.Resolver, registry *x.Registry) string { + if item == nil || registry == nil { + return "" + } + candidates := make([]string, 0, 8) + seen := map[string]bool{} + appendCandidate := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + if seen[value] { + return + } + seen[value] = true + candidates = append(candidates, value) + } + + if root && rootTypeKey != "" { + appendCandidate(rootTypeKey) + } + if item.Declaration != nil { + appendCandidate(item.Declaration.DataType) + appendCandidate(item.Declaration.Of) + } + appendCandidate(item.SchemaType) + name := toExportedTypeName(item.Name) + if name != "" { + appendCandidate(name + "View") + appendCandidate(name) + } + for _, candidate := range candidates { + if key := resolveTypeKey(candidate, resolver, registry); key != "" { + return key + } + } + return "" +} + +func resolveTypeKey(typeExpr string, resolver *typectx.Resolver, registry *x.Registry) string { + if registry == nil { + return "" + } + base := normalizeTypeLookupKey(typeExpr) + if base == "" { + return "" + } + if registry.Lookup(base) != nil { + return base + } + if resolver == nil { + 
return "" + } + resolved, err := resolver.Resolve(base) + if err != nil || resolved == "" { + return "" + } + if registry.Lookup(resolved) == nil { + return "" + } + return resolved +} + +func normalizeTypeLookupKey(typeExpr string) string { + value := strings.TrimSpace(typeExpr) + for { + switch { + case strings.HasPrefix(value, "*"): + value = strings.TrimPrefix(value, "*") + case strings.HasPrefix(value, "[]"): + value = strings.TrimPrefix(value, "[]") + default: + return strings.TrimSpace(value) + } + } +} + +func shouldSetSchemaType(item *plan.View) bool { + if item == nil { + return false + } + current := strings.TrimSpace(item.SchemaType) + if current == "" { + return true + } + expectedDefault := "*" + toExportedTypeName(item.Name) + "View" + return current == expectedDefault +} + +func existingTypesByName(input []*plan.Type) map[string]bool { + result := map[string]bool{} + for _, item := range input { + if item == nil { + continue + } + name := strings.ToLower(strings.TrimSpace(item.Name)) + if name == "" { + continue + } + result[name] = true + } + return result +} + +func schemaTypeExpression(rType reflect.Type, ctx *typectx.Context) (string, string) { + rType = unwrapResolvedType(rType) + if rType == nil { + return "", "" + } + typeName := strings.TrimSpace(rType.Name()) + if typeName == "" { + return "", "" + } + pkgPath := strings.TrimSpace(rType.PkgPath()) + if pkgPath == "" { + return "*" + typeName, "" + } + pkgAlias := packageAlias(pkgPath, ctx) + if pkgAlias == "" { + return "*" + typeName, "" + } + return "*" + pkgAlias + "." 
+ typeName, pkgAlias +} + +func packageAlias(pkgPath string, ctx *typectx.Context) string { + pkgPath = strings.TrimSpace(pkgPath) + if pkgPath == "" { + return "" + } + if ctx != nil { + for _, item := range ctx.Imports { + if strings.TrimSpace(item.Package) != pkgPath { + continue + } + alias := strings.TrimSpace(item.Alias) + if alias != "" { + return alias + } + } + if strings.TrimSpace(ctx.PackagePath) == pkgPath && strings.TrimSpace(ctx.PackageName) != "" { + return strings.TrimSpace(ctx.PackageName) + } + } + index := strings.LastIndex(pkgPath, "/") + if index == -1 || index+1 >= len(pkgPath) { + return pkgPath + } + return pkgPath[index+1:] +} + +func unwrapResolvedType(rType reflect.Type) reflect.Type { + for rType != nil { + switch rType.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Array: + rType = rType.Elem() + default: + return rType + } + } + return nil +} diff --git a/repository/shape/compile/type_support_test.go b/repository/shape/compile/type_support_test.go new file mode 100644 index 000000000..b081ad1e6 --- /dev/null +++ b/repository/shape/compile/type_support_test.go @@ -0,0 +1,70 @@ +package compile + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/x" +) + +type linkedRootType struct { + ID int +} + +type OrdersView struct { + ID int +} + +func TestDQLCompiler_Compile_UsesLinkedRootTypeForSchemaType(t *testing.T) { + compiler := New() + source := &shape.Source{ + Name: "orders_report", + Type: reflect.TypeOf(linkedRootType{}), + TypeName: x.NewType(reflect.TypeOf(linkedRootType{})).Key(), + DQL: "SELECT t.id FROM ORDERS t", + } + + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, 
"*compile.linkedRootType", planned.Views[0].SchemaType) + require.NotEmpty(t, planned.Types) + assert.Equal(t, "linkedRootType", planned.Types[0].Name) + assert.Equal(t, "*compile.linkedRootType", planned.Types[0].DataType) +} + +func TestDQLCompiler_Compile_UsesLinkedRegistryTypeForNamedView(t *testing.T) { + compiler := New() + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(OrdersView{}))) + source := &shape.Source{ + Name: "orders", + TypeRegistry: registry, + DQL: "SELECT orders.id FROM ORDERS orders", + } + + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "*compile.OrdersView", planned.Views[0].SchemaType) + + var found *plan.Type + for _, item := range planned.Types { + if item != nil && item.Name == "OrdersView" { + found = item + break + } + } + require.NotNil(t, found) + assert.Equal(t, "*compile.OrdersView", found.DataType) + assert.Equal(t, "compile", found.Package) +} diff --git a/repository/shape/compile/typectx_defaults.go b/repository/shape/compile/typectx_defaults.go new file mode 100644 index 000000000..5561d36a3 --- /dev/null +++ b/repository/shape/compile/typectx_defaults.go @@ -0,0 +1,226 @@ +package compile + +import ( + "os" + "path" + "path/filepath" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/typectx" + "golang.org/x/mod/modfile" +) + +func applyTypeContextDefaults(ctx *typectx.Context, source *shape.Source, opts *shape.CompileOptions, layout compilePathLayout) *typectx.Context { + ret := cloneTypeContext(ctx) + if shouldInferTypeContext(opts) { + ret = mergeTypeContext(ret, inferDatlyGenTypeContext(source, layout)) + } + if opts != nil { + ret = ensureTypeContext(ret) + if ret != nil { + if value := strings.TrimSpace(opts.TypePackageDir); value != "" { + ret.PackageDir = value + } + if value := 
strings.TrimSpace(opts.TypePackageName); value != "" { + ret.PackageName = value + } + if value := strings.TrimSpace(opts.TypePackagePath); value != "" { + ret.PackagePath = value + } + } + } + ret = normalizeRelativeImports(ret, source, layout) + return normalizeTypeContext(ret) +} + +func shouldInferTypeContext(opts *shape.CompileOptions) bool { + if opts == nil || opts.InferTypeContext == nil { + return true + } + return *opts.InferTypeContext +} + +func mergeTypeContext(dst *typectx.Context, src *typectx.Context) *typectx.Context { + if src == nil { + return dst + } + dst = ensureTypeContext(dst) + if strings.TrimSpace(dst.DefaultPackage) == "" { + dst.DefaultPackage = strings.TrimSpace(src.DefaultPackage) + } + if len(dst.Imports) == 0 && len(src.Imports) > 0 { + dst.Imports = append([]typectx.Import{}, src.Imports...) + } + if strings.TrimSpace(dst.PackageDir) == "" { + dst.PackageDir = strings.TrimSpace(src.PackageDir) + } + if strings.TrimSpace(dst.PackageName) == "" { + dst.PackageName = strings.TrimSpace(src.PackageName) + } + if strings.TrimSpace(dst.PackagePath) == "" { + dst.PackagePath = strings.TrimSpace(src.PackagePath) + } + return dst +} + +func inferDatlyGenTypeContext(source *shape.Source, layout compilePathLayout) *typectx.Context { + parsed, ok := parseSourceLayout(source, layout) + if !ok { + return nil + } + routeDir := strings.Trim(path.Dir(parsed.relativePath), "/") + if routeDir == "." 
{ + routeDir = "" + } + packageDir := "pkg" + if routeDir != "" { + packageDir = path.Join(packageDir, routeDir) + } + packageName := "main" + if routeDir != "" { + packageName = path.Base(routeDir) + } + packagePath := "" + if module := detectModulePath(parsed.projectRoot); module != "" { + packagePath = path.Join(module, packageDir) + } + return normalizeTypeContext(&typectx.Context{ + PackageDir: packageDir, + PackageName: packageName, + PackagePath: packagePath, + }) +} + +func detectModulePath(projectRoot string) string { + if strings.TrimSpace(projectRoot) == "" { + return "" + } + goModPath := filepath.Join(projectRoot, "go.mod") + data, err := os.ReadFile(goModPath) + if err != nil { + return "" + } + parsed, err := modfile.Parse(goModPath, data, nil) + if err != nil || parsed == nil || parsed.Module == nil { + return "" + } + return strings.TrimSpace(parsed.Module.Mod.Path) +} + +func ensureTypeContext(ctx *typectx.Context) *typectx.Context { + if ctx != nil { + return ctx + } + return &typectx.Context{} +} + +func cloneTypeContext(ctx *typectx.Context) *typectx.Context { + if ctx == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: strings.TrimSpace(ctx.DefaultPackage), + PackageDir: strings.TrimSpace(ctx.PackageDir), + PackageName: strings.TrimSpace(ctx.PackageName), + PackagePath: strings.TrimSpace(ctx.PackagePath), + } + if len(ctx.Imports) > 0 { + ret.Imports = append([]typectx.Import{}, ctx.Imports...) 
+ } + return ret +} + +func normalizeTypeContext(ctx *typectx.Context) *typectx.Context { + if ctx == nil { + return nil + } + if strings.TrimSpace(ctx.DefaultPackage) == "" && + len(ctx.Imports) == 0 && + strings.TrimSpace(ctx.PackageDir) == "" && + strings.TrimSpace(ctx.PackageName) == "" && + strings.TrimSpace(ctx.PackagePath) == "" { + return nil + } + return ctx +} + +func normalizeRelativeImports(ctx *typectx.Context, source *shape.Source, layout compilePathLayout) *typectx.Context { + if ctx == nil || len(ctx.Imports) == 0 { + return ctx + } + modulePath := modulePathForSource(source, layout) + if modulePath == "" { + return ctx + } + for i, item := range ctx.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + ctx.Imports[i].Package = normalizeImportPackage(pkg, modulePath) + } + return ctx +} + +func modulePathForSource(source *shape.Source, layout compilePathLayout) string { + parsed, ok := parseSourceLayout(source, layout) + if !ok { + return "" + } + return detectModulePath(parsed.projectRoot) +} + +func normalizeImportPackage(pkg, modulePath string) string { + pkg = strings.Trim(strings.ReplaceAll(strings.TrimSpace(pkg), "\\", "/"), "/") + if pkg == "" { + return "" + } + if !strings.Contains(pkg, "/") { + return pkg + } + if strings.HasPrefix(pkg, modulePath+"/") || pkg == modulePath { + return pkg + } + first := pkg + if index := strings.Index(first, "/"); index != -1 { + first = first[:index] + } + if strings.Contains(first, ".") { + return pkg + } + return path.Join(modulePath, pkg) +} + +type sourceLayout struct { + projectRoot string + relativePath string +} + +func parseSourceLayout(source *shape.Source, layout compilePathLayout) (*sourceLayout, bool) { + if source == nil { + return nil, false + } + sourcePath := strings.TrimSpace(source.Path) + if sourcePath == "" { + return nil, false + } + marker := strings.TrimSpace(layout.dqlMarker) + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + 
normalizedPath := filepath.ToSlash(filepath.Clean(sourcePath)) + idx := strings.Index(normalizedPath, marker) + if idx == -1 { + return nil, false + } + projectRoot := filepath.FromSlash(strings.TrimSuffix(normalizedPath[:idx], "/")) + relativePath := strings.TrimPrefix(normalizedPath[idx+len(marker):], "/") + if relativePath == "" { + return nil, false + } + return &sourceLayout{ + projectRoot: projectRoot, + relativePath: relativePath, + }, true +} diff --git a/repository/shape/compile/typectx_defaults_test.go b/repository/shape/compile/typectx_defaults_test.go new file mode 100644 index 000000000..aa3d01d8c --- /dev/null +++ b/repository/shape/compile/typectx_defaults_test.go @@ -0,0 +1,86 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func TestApplyTypeContextDefaults_Matrix(t *testing.T) { + layout := defaultCompilePathLayout() + + projectDir := t.TempDir() + err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module github.vianttech.com/viant/platform\n\ngo 1.23\n"), 0o644) + require.NoError(t, err) + source := &shape.Source{ + Path: filepath.Join(projectDir, "dql", "platform", "taxonomy", "taxonomy.dql"), + } + + t.Run("inferred only", func(t *testing.T) { + got := applyTypeContextDefaults(nil, source, nil, layout) + require.NotNil(t, got) + require.Equal(t, "pkg/platform/taxonomy", got.PackageDir) + require.Equal(t, "taxonomy", got.PackageName) + require.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/taxonomy", got.PackagePath) + }) + + t.Run("directive context wins over inferred", func(t *testing.T) { + input := &typectx.Context{ + DefaultPackage: "github.com/acme/manual", + PackageDir: "pkg/manual", + PackageName: "manual", + PackagePath: "github.com/acme/manual", + } + got := applyTypeContextDefaults(input, source, nil, layout) + require.NotNil(t, got) + 
require.Equal(t, "pkg/manual", got.PackageDir) + require.Equal(t, "manual", got.PackageName) + require.Equal(t, "github.com/acme/manual", got.PackagePath) + require.Equal(t, "github.com/acme/manual", got.DefaultPackage) + }) + + t.Run("compile override wins over both", func(t *testing.T) { + input := &typectx.Context{ + PackageDir: "pkg/manual", + PackageName: "manual", + PackagePath: "github.com/acme/manual", + } + got := applyTypeContextDefaults(input, source, &shape.CompileOptions{ + TypePackageDir: "pkg/override", + TypePackageName: "override", + TypePackagePath: "github.com/acme/override", + }, layout) + require.NotNil(t, got) + require.Equal(t, "pkg/override", got.PackageDir) + require.Equal(t, "override", got.PackageName) + require.Equal(t, "github.com/acme/override", got.PackagePath) + }) + + t.Run("explicitly disable inference", func(t *testing.T) { + disabled := false + got := applyTypeContextDefaults(nil, source, &shape.CompileOptions{ + InferTypeContext: &disabled, + }, layout) + require.Nil(t, got) + }) + + t.Run("relative imports are normalized to module path", func(t *testing.T) { + input := &typectx.Context{ + Imports: []typectx.Import{ + {Alias: "sess", Package: "pkg/platform/system/session"}, + {Alias: "perf", Package: "github.com/acme/perf"}, + {Alias: "time", Package: "time"}, + }, + } + got := applyTypeContextDefaults(input, source, nil, layout) + require.NotNil(t, got) + require.Len(t, got.Imports, 3) + require.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/system/session", got.Imports[0].Package) + require.Equal(t, "github.com/acme/perf", got.Imports[1].Package) + require.Equal(t, "time", got.Imports[2].Package) + }) +} diff --git a/repository/shape/compile/typectx_diagnostics.go b/repository/shape/compile/typectx_diagnostics.go new file mode 100644 index 000000000..36cc701f9 --- /dev/null +++ b/repository/shape/compile/typectx_diagnostics.go @@ -0,0 +1,37 @@ +package compile + +import ( + "fmt" + + dqldiag 
"github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func typeContextDiagnostics(ctx *typectx.Context, strict bool) []*dqlshape.Diagnostic { + issues := typectx.Validate(ctx) + if len(issues) == 0 { + return nil + } + severity := dqlshape.SeverityWarning + if strict { + severity = dqlshape.SeverityError + } + diags := make([]*dqlshape.Diagnostic, 0, len(issues)) + for _, issue := range issues { + if issue.Field == "" || issue.Message == "" { + continue + } + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeTypeCtxInvalid, + Severity: severity, + Message: fmt.Sprintf("type context %s: %s", issue.Field, issue.Message), + Hint: "set consistent TypeContext package fields or use compile overrides", + Span: dqlshape.Span{ + Start: dqlshape.Position{Line: 1, Char: 1}, + End: dqlshape.Position{Line: 1, Char: 1}, + }, + }) + } + return diags +} diff --git a/repository/shape/compile/viewdecl.go b/repository/shape/compile/viewdecl.go new file mode 100644 index 000000000..9e6c14c88 --- /dev/null +++ b/repository/shape/compile/viewdecl.go @@ -0,0 +1,107 @@ +package compile + +import ( + "fmt" + "strings" + + "github.com/viant/datly/repository/shape/compile/pipeline" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/parsly" + "github.com/viant/parsly/matcher" +) + +type declaredView struct { + Name string + SQL string + URI string + Connector string + Cardinality string + Tag string + Codec string + CodecArgs []string + HandlerName string + HandlerArgs []string + StatusCode *int + ErrorMessage string + QuerySelector string + CacheRef string + Limit *int + Cacheable *bool + When string + Scope string + DataType string + Of string + Value string + Async bool + Output bool + Predicates []declaredPredicate +} + +type declaredPredicate struct { + Name string + 
Source string + Ensure bool + Arguments []string +} + +const ( + vdWhitespaceToken = iota + vdSetToken + vdDefineToken + vdExprGroupToken + vdCommentToken + vdParamDeclToken + vdTypeToken + vdDotToken +) + +var ( + vdWhitespaceMatcher = parsly.NewToken(vdWhitespaceToken, "Whitespace", matcher.NewWhiteSpace()) + vdSetMatcher = parsly.NewToken(vdSetToken, "#set", matcher.NewFragment("#set")) + vdDefineMatcher = parsly.NewToken(vdDefineToken, "#define", matcher.NewFragment("#define")) + vdExprGroupMatcher = parsly.NewToken(vdExprGroupToken, "( ... )", matcher.NewBlock('(', ')', '\\')) + vdCommentMatcher = parsly.NewToken(vdCommentToken, "Comment", matcher.NewSeqBlock("/*", "*/")) + vdParamDeclMatcher = parsly.NewToken(vdParamDeclToken, "$_ = $", matcher.NewSpacedSet([]string{"$_ = $"})) + vdTypeMatcher = parsly.NewToken(vdTypeToken, "< ... >", matcher.NewSeqBlock("<", ">")) + vdDotMatcher = parsly.NewToken(vdDotToken, ".", matcher.NewByte('.')) +) + +func extractDeclaredViews(dql string) ([]*declaredView, []*dqlshape.Diagnostic) { + if strings.TrimSpace(dql) == "" { + return nil, nil + } + var views []*declaredView + var diags []*dqlshape.Diagnostic + for _, block := range extractSetBlocks(dql) { + holder, kind, location, tail, ok := parseSetDeclarationBody(block.Body) + if !ok { + continue + } + if kind != "view" && kind != "data_view" { + continue + } + sqlText := extractDeclarationSQL(tail) + if sqlText == "" { + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeViewMissingSQL, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("view declaration %q has no inline SQL hint", location), + Hint: "use /* SELECT ... 
*/ in declaration comment to derive an additional view", + Span: relationSpan(dql, block.Offset), + }) + continue + } + name := pipeline.SanitizeName(location) + if name == "" { + name = pipeline.SanitizeName(holder) + } + if name == "" { + continue + } + view := &declaredView{Name: name, SQL: strings.TrimSpace(sqlText)} + applyDeclaredViewOptions(view, tail, dql, block.Offset, &diags) + views = append(views, view) + } + return views, diags +} diff --git a/repository/shape/compile/viewdecl_append.go b/repository/shape/compile/viewdecl_append.go new file mode 100644 index 000000000..6ae2fc663 --- /dev/null +++ b/repository/shape/compile/viewdecl_append.go @@ -0,0 +1,188 @@ +package compile + +import ( + "reflect" + "strings" + + "github.com/viant/datly/repository/shape/compile/pipeline" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" +) + +func appendDeclaredViews(rawDQL string, result *plan.Result) { + if result == nil { + return + } + declared, diags := extractDeclaredViews(rawDQL) + if len(diags) > 0 { + result.Diagnostics = append(result.Diagnostics, diags...) 
+ } + for _, item := range declared { + if item == nil || strings.TrimSpace(item.Name) == "" || strings.TrimSpace(item.SQL) == "" { + continue + } + if parent := lookupSummaryParentView(result, item.SQL); parent != nil { + if strings.TrimSpace(parent.Summary) == "" { + parent.Summary = strings.TrimSpace(item.SQL) + } + continue + } + if _, exists := result.ViewsByName[item.Name]; exists { + continue + } + view := &plan.View{ + Path: item.Name, + Holder: item.Name, + Name: item.Name, + Table: item.Name, + SQL: item.SQL, + SQLURI: item.URI, + Connector: item.Connector, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Declaration: buildViewDeclaration(item), + } + if item.Cardinality != "" { + view.Cardinality = item.Cardinality + } + if queryNode, err := sqlparser.ParseQuery(item.SQL); err == nil && queryNode != nil { + if inferredName, inferredTable, err := pipeline.InferRoot(queryNode, item.Name); err == nil { + view.Name = inferredName + view.Holder = inferredName + view.Path = inferredName + view.Table = inferredTable + } + if fType, eType, card := pipeline.InferProjectionType(queryNode); fType != nil && eType != nil { + view.FieldType = fType + view.ElementType = eType + if item.Cardinality == "" { + view.Cardinality = card + } + } + } + result.Views = append(result.Views, view) + result.ViewsByName[view.Name] = view + } +} + +func lookupSummaryParentView(result *plan.Result, sqlText string) *plan.View { + if result == nil || strings.TrimSpace(sqlText) == "" { + return nil + } + parent, ok := findSummaryParentReference(sqlText) + if !ok { + return nil + } + if view, ok := result.ViewsByName[parent]; ok && view != nil { + return view + } + sanitized := pipeline.SanitizeName(parent) + if sanitized != "" { + if view, ok := result.ViewsByName[sanitized]; ok && view != nil { + return view + } + } + for name, view := range result.ViewsByName { + if view == nil { + continue + } + 
if strings.EqualFold(strings.TrimSpace(name), parent) || (sanitized != "" && strings.EqualFold(strings.TrimSpace(name), sanitized)) { + return view + } + } + for _, view := range result.Views { + if view == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(view.Name), parent) || (sanitized != "" && strings.EqualFold(strings.TrimSpace(view.Name), sanitized)) { + return view + } + } + return nil +} + +func findSummaryParentReference(input string) (string, bool) { + if strings.TrimSpace(input) == "" { + return "", false + } + lower := strings.ToLower(input) + for i := 0; i+len("$view.") < len(lower); i++ { + if lower[i] != '$' { + continue + } + if !strings.HasPrefix(lower[i:], "$view.") { + continue + } + start := i + len("$view.") + if start >= len(input) || !isCompileIdentifierStart(input[start]) { + continue + } + end := start + 1 + for end < len(input) && isCompileIdentifierPart(input[end]) { + end++ + } + if !strings.HasPrefix(lower[end:], ".sql") { + continue + } + parent := strings.TrimSpace(input[start:end]) + if parent == "" { + continue + } + return parent, true + } + return "", false +} + +func isCompileIdentifierStart(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' +} + +func isCompileIdentifierPart(ch byte) bool { + return isCompileIdentifierStart(ch) || (ch >= '0' && ch <= '9') +} + +func buildViewDeclaration(item *declaredView) *plan.ViewDeclaration { + if item == nil { + return nil + } + ret := &plan.ViewDeclaration{ + Tag: item.Tag, + Codec: item.Codec, + CodecArgs: append([]string{}, item.CodecArgs...), + HandlerName: item.HandlerName, + HandlerArgs: append([]string{}, item.HandlerArgs...), + StatusCode: item.StatusCode, + ErrorMessage: item.ErrorMessage, + QuerySelector: item.QuerySelector, + CacheRef: item.CacheRef, + Limit: item.Limit, + Cacheable: item.Cacheable, + When: item.When, + Scope: item.Scope, + DataType: item.DataType, + Of: item.Of, + Value: item.Value, + Async: item.Async, + 
Output: item.Output, + } + if len(item.Predicates) > 0 { + ret.Predicates = make([]*plan.ViewPredicate, 0, len(item.Predicates)) + for _, predicate := range item.Predicates { + ret.Predicates = append(ret.Predicates, &plan.ViewPredicate{ + Name: predicate.Name, + Source: predicate.Source, + Ensure: predicate.Ensure, + Arguments: append([]string{}, predicate.Arguments...), + }) + } + } + if ret.Tag == "" && ret.Codec == "" && len(ret.CodecArgs) == 0 && ret.HandlerName == "" && + len(ret.HandlerArgs) == 0 && ret.StatusCode == nil && ret.ErrorMessage == "" && + ret.QuerySelector == "" && ret.CacheRef == "" && ret.Limit == nil && ret.Cacheable == nil && + ret.When == "" && ret.Scope == "" && ret.DataType == "" && ret.Of == "" && ret.Value == "" && + !ret.Async && !ret.Output && len(ret.Predicates) == 0 { + return nil + } + return ret +} diff --git a/repository/shape/compile/viewdecl_options.go b/repository/shape/compile/viewdecl_options.go new file mode 100644 index 000000000..dd8ea2fba --- /dev/null +++ b/repository/shape/compile/viewdecl_options.go @@ -0,0 +1,382 @@ +package compile + +import ( + "fmt" + "strconv" + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/parsly" +) + +func extractDeclarationSQL(fragment string) string { + cursor := parsly.NewCursor("", []byte(fragment), 0) + for cursor.Pos < cursor.InputSize { + match := cursor.MatchAfterOptional(vdWhitespaceMatcher, vdCommentMatcher) + if match.Code == vdCommentToken { + text := match.Text(cursor) + if len(text) < 4 { + return "" + } + return normalizeHintSQL(text[2 : len(text)-2]) + } + cursor.Pos++ + } + return "" +} + +func normalizeHintSQL(body string) string { + body = strings.TrimSpace(body) + if body == "" { + return "" + } + if strings.HasPrefix(body, "{") { + if closeIdx := strings.Index(body, "}"); closeIdx != -1 { + body = strings.TrimSpace(body[closeIdx+1:]) + } + } + if body == "" { + 
return "" + } + switch body[0] { + case '?': + body = strings.TrimSpace(body[1:]) + case '!': + body = strings.TrimSpace(body[1:]) + if strings.HasPrefix(body, "!") { + body = strings.TrimSpace(body[1:]) + } + if len(body) >= 3 { + var status int + if _, err := fmt.Sscanf(body[:3], "%d", &status); err == nil { + body = strings.TrimSpace(body[3:]) + } + } + } + return strings.TrimSpace(body) +} + +func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, diags *[]*dqlshape.Diagnostic) { + if view == nil || strings.TrimSpace(tail) == "" { + return + } + cursor := parsly.NewCursor("", []byte(tail), 0) + for cursor.Pos < cursor.InputSize { + _ = cursor.MatchOne(vdWhitespaceMatcher) + if cursor.MatchOne(vdDotMatcher).Code != vdDotToken { + cursor.Pos++ + continue + } + _ = cursor.MatchOne(vdWhitespaceMatcher) + name, ok := readIdentifier(cursor) + if !ok { + continue + } + _ = cursor.MatchOne(vdWhitespaceMatcher) + group := cursor.MatchOne(vdExprGroupMatcher) + if group.Code != vdExprGroupToken { + continue + } + content := group.Text(cursor) + if len(content) < 2 { + continue + } + args := splitArgs(content[1 : len(content)-1]) + switch { + case strings.EqualFold(name, "WithURI"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.URI = trimQuote(args[0]) + case strings.EqualFold(name, "WithConnector"), strings.EqualFold(name, "Connector"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.Connector = trimQuote(args[0]) + case strings.EqualFold(name, "Cardinality"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + card := strings.ToLower(strings.TrimSpace(trimQuote(args[0]))) + switch card { + case "one", "many": + view.Cardinality = card + default: + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeViewCardinality, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("unsupported cardinality %q for declared view %q", 
args[0], view.Name), + Hint: "use Cardinality('one') or Cardinality('many')", + Span: relationSpan(dql, offset), + }) + } + case strings.EqualFold(name, "WithTag"), strings.EqualFold(name, "Tag"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Tag = trimQuote(args[0]) + case strings.EqualFold(name, "WithCodec"), strings.EqualFold(name, "Codec"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.Codec = trimQuote(args[0]) + view.CodecArgs = nil + for _, arg := range args[1:] { + view.CodecArgs = append(view.CodecArgs, strings.TrimSpace(arg)) + } + case strings.EqualFold(name, "WithHandler"), strings.EqualFold(name, "Handler"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.HandlerName = trimQuote(args[0]) + view.HandlerArgs = nil + for _, arg := range args[1:] { + view.HandlerArgs = append(view.HandlerArgs, strings.TrimSpace(arg)) + } + case strings.EqualFold(name, "WithStatusCode"), strings.EqualFold(name, "StatusCode"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + statusCode, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDeclOptionArgs, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("invalid status code %q for declared view %q", args[0], view.Name), + Hint: "use numeric status code, e.g. 
StatusCode(400)", + Span: relationSpan(dql, offset), + }) + continue + } + view.StatusCode = &statusCode + case strings.EqualFold(name, "WithErrorMessage"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.ErrorMessage = trimQuote(args[0]) + case strings.EqualFold(name, "WithPredicate"), strings.EqualFold(name, "Predicate"): + if !expectArgs(view, name, args, 2, -1, dql, offset, diags) { + continue + } + view.Predicates = append(view.Predicates, declaredPredicate{ + Name: trimQuote(args[0]), + Source: trimQuote(args[1]), + Arguments: append([]string{}, args[2:]...), + }) + case strings.EqualFold(name, "EnsurePredicate"): + if !expectArgs(view, name, args, 2, -1, dql, offset, diags) { + continue + } + view.Predicates = append(view.Predicates, declaredPredicate{ + Name: trimQuote(args[0]), + Source: trimQuote(args[1]), + Ensure: true, + Arguments: append([]string{}, args[2:]...), + }) + case strings.EqualFold(name, "QuerySelector"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.QuerySelector = trimQuote(args[0]) + if !isAllowedQuerySelector(strings.ToLower(view.Name)) { + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDeclQuerySelector, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("query selector %q can only be used with limit, offset, page, fields, orderby", view.QuerySelector), + Hint: "use QuerySelector on declarations named limit/offset/page/fields/orderby", + Span: relationSpan(dql, offset), + }) + } + case strings.EqualFold(name, "WithCache"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.CacheRef = trimQuote(args[0]) + case strings.EqualFold(name, "WithLimit"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + limit, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid integer limit %q", args[0]), 
dql, offset, diags) + continue + } + view.Limit = &limit + case strings.EqualFold(name, "Cacheable"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + value, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid bool cacheable %q", args[0]), dql, offset, diags) + continue + } + view.Cacheable = &value + case strings.EqualFold(name, "When"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.When = trimQuote(args[0]) + case strings.EqualFold(name, "Scope"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Scope = trimQuote(args[0]) + case strings.EqualFold(name, "WithType"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.DataType = trimQuote(args[0]) + case strings.EqualFold(name, "Of"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Of = trimQuote(args[0]) + case strings.EqualFold(name, "Value"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Value = trimQuote(args[0]) + case strings.EqualFold(name, "Async"): + if !expectArgs(view, name, args, 0, 0, dql, offset, diags) { + continue + } + view.Async = true + case strings.EqualFold(name, "Output"): + if !expectArgs(view, name, args, 0, 0, dql, offset, diags) { + continue + } + view.Output = true + } + } +} + +func splitArgs(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + var result []string + var current strings.Builder + inSingle := false + inDouble := false + escape := false + parens := 0 + brackets := 0 + braces := 0 + for i := 0; i < len(raw); i++ { + ch := raw[i] + if escape { + current.WriteByte(ch) + escape = false + continue + } + switch ch { + case '\\': + current.WriteByte(ch) + escape = true + case '\'': + if !inDouble { + inSingle = !inSingle + } + 
current.WriteByte(ch) + case '"': + if !inSingle { + inDouble = !inDouble + } + current.WriteByte(ch) + case '(': + if !inSingle && !inDouble { + parens++ + } + current.WriteByte(ch) + case ')': + if !inSingle && !inDouble && parens > 0 { + parens-- + } + current.WriteByte(ch) + case '[': + if !inSingle && !inDouble { + brackets++ + } + current.WriteByte(ch) + case ']': + if !inSingle && !inDouble && brackets > 0 { + brackets-- + } + current.WriteByte(ch) + case '{': + if !inSingle && !inDouble { + braces++ + } + current.WriteByte(ch) + case '}': + if !inSingle && !inDouble && braces > 0 { + braces-- + } + current.WriteByte(ch) + case ',': + if inSingle || inDouble || parens > 0 || brackets > 0 || braces > 0 { + current.WriteByte(ch) + continue + } + part := strings.TrimSpace(current.String()) + if part != "" { + result = append(result, part) + } + current.Reset() + default: + current.WriteByte(ch) + } + } + if tail := strings.TrimSpace(current.String()); tail != "" { + result = append(result, tail) + } + return result +} + +func trimQuote(v string) string { + v = strings.TrimSpace(v) + if len(v) >= 2 { + if (v[0] == '\'' && v[len(v)-1] == '\'') || (v[0] == '"' && v[len(v)-1] == '"') { + return v[1 : len(v)-1] + } + } + return v +} + +func expectArgs(view *declaredView, option string, args []string, min, max int, dql string, offset int, diags *[]*dqlshape.Diagnostic) bool { + if len(args) < min { + appendOptionArgDiagnostic(view, option, fmt.Sprintf("expected at least %d args, got %d", min, len(args)), dql, offset, diags) + return false + } + if max >= 0 && len(args) > max { + appendOptionArgDiagnostic(view, option, fmt.Sprintf("expected at most %d args, got %d", max, len(args)), dql, offset, diags) + return false + } + return true +} + +func appendOptionArgDiagnostic(view *declaredView, option, detail, dql string, offset int, diags *[]*dqlshape.Diagnostic) { + viewName := "" + if view != nil { + viewName = view.Name + } + *diags = append(*diags, 
&dqlshape.Diagnostic{ + Code: dqldiag.CodeDeclOptionArgs, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("invalid %s declaration for view %q: %s", option, viewName, detail), + Hint: "check option arity and argument formatting", + Span: relationSpan(dql, offset), + }) +} + +func isAllowedQuerySelector(name string) bool { + switch strings.ToLower(strings.TrimSpace(name)) { + case "limit", "offset", "page", "fields", "orderby": + return true + default: + return false + } +} diff --git a/repository/shape/compile/viewdecl_parity_test.go b/repository/shape/compile/viewdecl_parity_test.go new file mode 100644 index 000000000..03d455521 --- /dev/null +++ b/repository/shape/compile/viewdecl_parity_test.go @@ -0,0 +1,66 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func TestViewDecl_ParityFixtures(t *testing.T) { + testCases := []struct { + name string + viewName string + tail string + expectDiag string + expectTag string + expectCodec string + expectHandler string + expectPreds int + }{ + { + name: "tag/codec/handler", + viewName: "limit", + tail: ".WithTag('json:\"id\"').WithCodec(AsJSON).WithHandler('Build')", + expectTag: `json:"id"`, + expectCodec: "AsJSON", + expectHandler: "Build", + }, + { + name: "status arg validation", + viewName: "limit", + tail: ".WithStatusCode('x')", + expectDiag: dqldiag.CodeDeclOptionArgs, + }, + { + name: "query selector validation", + viewName: "customer_id", + tail: ".QuerySelector('items')", + expectDiag: dqldiag.CodeDeclQuerySelector, + }, + { + name: "predicate forms", + viewName: "limit", + tail: ".WithPredicate('ByID','id=?',1).EnsurePredicate('Tenant','tenant=?',2)", + expectPreds: 2, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + view := &declaredView{Name: 
testCase.viewName} + var diags []*dqlshape.Diagnostic + applyDeclaredViewOptions(view, testCase.tail, "SELECT 1", 0, &diags) + if testCase.expectDiag != "" { + require.NotEmpty(t, diags) + assert.Equal(t, testCase.expectDiag, diags[0].Code) + return + } + assert.Equal(t, testCase.expectTag, view.Tag) + assert.Equal(t, testCase.expectCodec, view.Codec) + assert.Equal(t, testCase.expectHandler, view.HandlerName) + assert.Len(t, view.Predicates, testCase.expectPreds) + }) + } +} diff --git a/repository/shape/compile/viewdecl_parse.go b/repository/shape/compile/viewdecl_parse.go new file mode 100644 index 000000000..51fd45a9d --- /dev/null +++ b/repository/shape/compile/viewdecl_parse.go @@ -0,0 +1,90 @@ +package compile + +import ( + "strings" + "unicode" + + "github.com/viant/parsly" +) + +type setBlock struct { + Offset int + Body string +} + +func extractSetBlocks(dql string) []setBlock { + cursor := parsly.NewCursor("", []byte(dql), 0) + var result []setBlock + for cursor.Pos < cursor.InputSize { + matched := cursor.MatchAfterOptional(vdWhitespaceMatcher, vdSetMatcher, vdDefineMatcher) + if matched.Code != vdSetToken && matched.Code != vdDefineToken { + cursor.Pos++ + continue + } + offset := cursor.Pos - len(matched.Text(cursor)) + group := cursor.MatchAfterOptional(vdWhitespaceMatcher, vdExprGroupMatcher) + if group.Code != vdExprGroupToken { + continue + } + body := group.Text(cursor) + if len(body) < 2 { + continue + } + result = append(result, setBlock{ + Offset: offset, + Body: body[1 : len(body)-1], + }) + } + return result +} + +func parseSetDeclarationBody(body string) (holder, kind, location, tail string, ok bool) { + cursor := parsly.NewCursor("", []byte(body), 0) + if cursor.MatchAfterOptional(vdWhitespaceMatcher, vdParamDeclMatcher).Code != vdParamDeclToken { + return "", "", "", "", false + } + id, matched := readIdentifier(cursor) + if !matched { + return "", "", "", "", false + } + holder = id + _ = cursor.MatchOne(vdWhitespaceMatcher) + _ = 
cursor.MatchOne(vdTypeMatcher) + _ = cursor.MatchOne(vdWhitespaceMatcher) + kindLoc := cursor.MatchOne(vdExprGroupMatcher) + if kindLoc.Code != vdExprGroupToken { + return "", "", "", "", false + } + inGroup := kindLoc.Text(cursor) + if len(inGroup) < 2 { + return "", "", "", "", false + } + raw := strings.TrimSpace(inGroup[1 : len(inGroup)-1]) + slash := strings.Index(raw, "/") + if slash == -1 { + return "", "", "", "", false + } + kind = strings.ToLower(strings.TrimSpace(raw[:slash])) + location = strings.TrimSpace(raw[slash+1:]) + tail = strings.TrimSpace(string(cursor.Input[cursor.Pos:])) + return holder, kind, location, tail, true +} + +func readIdentifier(cursor *parsly.Cursor) (string, bool) { + if cursor.Pos >= cursor.InputSize { + return "", false + } + start := cursor.Pos + for cursor.Pos < cursor.InputSize { + ch := rune(cursor.Input[cursor.Pos]) + if ch == '_' || ch == '$' || unicode.IsLetter(ch) || unicode.IsDigit(ch) { + cursor.Pos++ + continue + } + break + } + if cursor.Pos == start { + return "", false + } + return string(cursor.Input[start:cursor.Pos]), true +} diff --git a/repository/shape/compile/viewdecl_test.go b/repository/shape/compile/viewdecl_test.go new file mode 100644 index 000000000..0136c64a1 --- /dev/null +++ b/repository/shape/compile/viewdecl_test.go @@ -0,0 +1,187 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func TestViewDecl_ExtractSetBlocks(t *testing.T) { + dql := "#set($_ = $Extra(view/extra_view) /* SELECT id FROM EXTRA e */)\n" + + "#define($_ = $Extra2(view/extra_view_2) /* SELECT id FROM EXTRA2 e */)\n" + + "SELECT id FROM ORDERS o" + blocks := extractSetBlocks(dql) + require.Len(t, blocks, 2) + assert.Contains(t, blocks[0].Body, "$Extra") + assert.Contains(t, 
blocks[1].Body, "$Extra2") +} + +func TestViewDecl_ParseSetDeclarationBody(t *testing.T) { + holder, kind, location, tail, ok := parseSetDeclarationBody("$_ = $Extra(view/extra_view).WithURI('/x')") + require.True(t, ok) + assert.Equal(t, "Extra", holder) + assert.Equal(t, "view", kind) + assert.Equal(t, "extra_view", location) + assert.Contains(t, tail, ".WithURI('/x')") +} + +func TestViewDecl_ApplyOptions_InvalidCardinality(t *testing.T) { + view := &declaredView{Name: "extra"} + var diags []*dqlshape.Diagnostic + applyDeclaredViewOptions(view, ".Cardinality('few')", "SELECT 1", 0, &diags) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeViewCardinality, diags[0].Code) +} + +func TestViewDecl_AppendDeclaredViews(t *testing.T) { + dql := "#set($_ = $Extra(view/extra_view).WithURI('/x') /* SELECT code FROM EXTRA e */)" + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + appendDeclaredViews(dql, result) + require.NotEmpty(t, result.Views) + found := false + for _, item := range result.Views { + if item != nil && item.SQLURI == "/x" { + found = true + break + } + } + assert.True(t, found) +} + +func TestViewDecl_ApplyOptions_Extended(t *testing.T) { + view := &declaredView{Name: "limit"} + var diags []*dqlshape.Diagnostic + tail := ".WithTag('json:\"id\"').WithCodec(AsJSON,'x').WithHandler('Build',a,b)." + + "WithStatusCode(422).WithErrorMessage('bad req').WithPredicate('ByID','id = ?', 101)." + + "EnsurePredicate('Tenant','tenant_id = ?', 7).QuerySelector('qs').WithCache('c1').WithLimit(10)." 
+ + "Cacheable(true).When('x > 1').Scope('team').WithType('[]Order').Of('list').Value('abc').Async().Output()" + applyDeclaredViewOptions(view, tail, "SELECT 1", 0, &diags) + + require.Empty(t, diags) + assert.Equal(t, `json:"id"`, view.Tag) + assert.Equal(t, "AsJSON", view.Codec) + require.Len(t, view.CodecArgs, 1) + assert.Equal(t, "'x'", view.CodecArgs[0]) + assert.Equal(t, "Build", view.HandlerName) + require.Len(t, view.HandlerArgs, 2) + assert.Equal(t, "a", view.HandlerArgs[0]) + assert.Equal(t, "b", view.HandlerArgs[1]) + require.NotNil(t, view.StatusCode) + assert.Equal(t, 422, *view.StatusCode) + assert.Equal(t, "bad req", view.ErrorMessage) + require.Len(t, view.Predicates, 2) + assert.Equal(t, "ByID", view.Predicates[0].Name) + assert.False(t, view.Predicates[0].Ensure) + assert.Equal(t, "Tenant", view.Predicates[1].Name) + assert.True(t, view.Predicates[1].Ensure) + assert.Equal(t, "qs", view.QuerySelector) + assert.Equal(t, "c1", view.CacheRef) + require.NotNil(t, view.Limit) + assert.Equal(t, 10, *view.Limit) + require.NotNil(t, view.Cacheable) + assert.True(t, *view.Cacheable) + assert.Equal(t, "x > 1", view.When) + assert.Equal(t, "team", view.Scope) + assert.Equal(t, "[]Order", view.DataType) + assert.Equal(t, "list", view.Of) + assert.Equal(t, "abc", view.Value) + assert.True(t, view.Async) + assert.True(t, view.Output) +} + +func TestViewDecl_ApplyOptions_QuerySelectorValidation(t *testing.T) { + view := &declaredView{Name: "customer_id"} + var diags []*dqlshape.Diagnostic + applyDeclaredViewOptions(view, ".QuerySelector('q')", "SELECT 1", 0, &diags) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeDeclQuerySelector, diags[0].Code) +} + +func TestViewDecl_SplitArgs_Nested(t *testing.T) { + args := splitArgs(`'a', fn(1,2), {'k': [1,2]}, "x,y"`) + require.Len(t, args, 4) + assert.Equal(t, "'a'", args[0]) + assert.Equal(t, "fn(1,2)", args[1]) + assert.Equal(t, "{'k': [1,2]}", args[2]) + assert.Equal(t, `"x,y"`, args[3]) +} + +func 
TestViewDecl_AppendDeclaredViews_ExtendedDeclarationMetadata(t *testing.T) { + dql := "#set($_ = $limit(view/limit).WithTag('json:\"id\"').WithCodec(AsJSON).WithHandler('Build',a)." + + "WithStatusCode(409).WithErrorMessage('conflict').WithPredicate('ByID','id=?',1)." + + "EnsurePredicate('Tenant','tenant=?',2).QuerySelector('items').WithCache('c1').WithLimit(5)." + + "Cacheable(false).When('x').Scope('s').WithType('Order').Of('o').Value('v').Async().Output() /* SELECT id FROM EXTRA e */)" + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + appendDeclaredViews(dql, result) + require.NotEmpty(t, result.Views) + var target *plan.View + for _, item := range result.Views { + if item != nil && item.Name == "e" { + target = item + break + } + } + require.NotNil(t, target) + require.NotNil(t, target.Declaration) + assert.Equal(t, `json:"id"`, target.Declaration.Tag) + assert.Equal(t, "AsJSON", target.Declaration.Codec) + assert.Equal(t, "Build", target.Declaration.HandlerName) + require.NotNil(t, target.Declaration.StatusCode) + assert.Equal(t, 409, *target.Declaration.StatusCode) + assert.Equal(t, "conflict", target.Declaration.ErrorMessage) + assert.Equal(t, "items", target.Declaration.QuerySelector) + assert.Equal(t, "c1", target.Declaration.CacheRef) + require.NotNil(t, target.Declaration.Limit) + assert.Equal(t, 5, *target.Declaration.Limit) + require.NotNil(t, target.Declaration.Cacheable) + assert.False(t, *target.Declaration.Cacheable) + assert.Equal(t, "x", target.Declaration.When) + assert.Equal(t, "s", target.Declaration.Scope) + assert.Equal(t, "Order", target.Declaration.DataType) + assert.Equal(t, "o", target.Declaration.Of) + assert.Equal(t, "v", target.Declaration.Value) + assert.True(t, target.Declaration.Async) + assert.True(t, target.Declaration.Output) + require.Len(t, target.Declaration.Predicates, 2) +} + +func TestViewDecl_AppendDeclaredViews_AttachSummaryFromMetaViewSQL(t *testing.T) { + root 
:= &plan.View{Name: "Browser", Path: "Browser", Holder: "Browser"} + result := &plan.Result{ + Views: []*plan.View{root}, + ViewsByName: map[string]*plan.View{"Browser": root}, + ByPath: map[string]*plan.Field{}, + } + dql := "#set($_ = $Summary(view/summary) /* SELECT COUNT(1) CNT FROM ($View.browser.SQL) t */)" + + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + require.NotNil(t, root) + assert.Contains(t, root.Summary, "COUNT(1)") + assert.Contains(t, root.Summary, "$View.browser.SQL") +} + +func TestViewDecl_AppendDeclaredViews_MetaViewSQL_NoParentFallbackToView(t *testing.T) { + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + dql := "#set($_ = $Summary(view/summary) /* SELECT COUNT(1) CNT FROM ($View.browser.SQL) t */)" + + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + assert.Empty(t, result.Views[0].Summary) + assert.NotEmpty(t, result.Views[0].Name) +} diff --git a/repository/shape/doc.go b/repository/shape/doc.go new file mode 100644 index 000000000..730ab1395 --- /dev/null +++ b/repository/shape/doc.go @@ -0,0 +1,3 @@ +// Package shape provides building blocks for dynamic repository loading from +// struct and DQL sources without requiring persisted YAML artifacts. 
+package shape diff --git a/repository/shape/dql/decl/lex.go b/repository/shape/dql/decl/lex.go new file mode 100644 index 000000000..fbf6270ac --- /dev/null +++ b/repository/shape/dql/decl/lex.go @@ -0,0 +1,59 @@ +package decl + +import ( + "github.com/viant/parsly" + "github.com/viant/parsly/matcher" +) + +const ( + whitespaceToken = iota + singleQuotedToken + doubleQuotedToken + commentBlockToken + parenthesesBlockToken + identifierToken + anyToken +) + +var whitespaceMatcher = parsly.NewToken(whitespaceToken, "Whitespace", matcher.NewWhiteSpace()) +var singleQuotedMatcher = parsly.NewToken(singleQuotedToken, "SingleQuote", matcher.NewBlock('\'', '\'', '\\')) +var doubleQuotedMatcher = parsly.NewToken(doubleQuotedToken, "DoubleQuote", matcher.NewBlock('"', '"', '\\')) +var commentBlockMatcher = parsly.NewToken(commentBlockToken, "CommentBlock", matcher.NewSeqBlock("/*", "*/")) +var parenthesesBlockMatcher = parsly.NewToken(parenthesesBlockToken, "Parentheses", matcher.NewBlock('(', ')', '\\')) + +var identifierMatcher = parsly.NewToken(identifierToken, "Identifier", &identifierMatch{}) +var anyMatcher = parsly.NewToken(anyToken, "Any", &anyMatch{}) + +type anyMatch struct{} + +func (a *anyMatch) Match(cursor *parsly.Cursor) int { + if cursor.Pos < cursor.InputSize { + return 1 + } + return 0 +} + +type identifierMatch struct{} + +func (i *identifierMatch) Match(cursor *parsly.Cursor) int { + if cursor.Pos >= cursor.InputSize { + return 0 + } + b := cursor.Input[cursor.Pos] + if !isIdentifierStart(b) { + return 0 + } + pos := cursor.Pos + 1 + for pos < cursor.InputSize && isIdentifierPart(cursor.Input[pos]) { + pos++ + } + return pos - cursor.Pos +} + +func isIdentifierStart(b byte) bool { + return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || b == '_' +} + +func isIdentifierPart(b byte) bool { + return isIdentifierStart(b) || (b >= '0' && b <= '9') +} diff --git a/repository/shape/dql/decl/model.go b/repository/shape/dql/decl/model.go new file mode 100644 
index 000000000..914d3493c --- /dev/null +++ b/repository/shape/dql/decl/model.go @@ -0,0 +1,42 @@ +package decl + +// Kind identifies parsed declaration function. +type Kind string + +const ( + KindCast Kind = "cast" + KindTag Kind = "tag" + KindSetLimit Kind = "set_limit" + KindAllowNulls Kind = "allow_nulls" + KindSetPartitioner Kind = "set_partitioner" + KindUseConnector Kind = "use_connector" + KindMatchStrategy Kind = "match_strategy" + KindCompressAboveSize Kind = "compress_above_size" + KindBatchSize Kind = "batch_size" + KindRelationalConcurrency Kind = "relational_concurrency" + KindPublishParent Kind = "publish_parent" + KindCardinality Kind = "cardinality" + KindPackage Kind = "package" + KindImport Kind = "import" +) + +// Declaration represents one parsed function declaration in DQL. +type Declaration struct { + Kind Kind + Raw string + Offset int + Args []string + + // Normalized fields for known declarations. + Target string // first argument (alias/column) + DataType string // cast(... as type) + Tag string // tag(..., "...") payload + Limit string // set_limit(..., N) + Connector string // use_connector(view, connector) + Strategy string // match_strategy(view, strategy) + Partition string // set_partitioner(view, partitioner, concurrency) + Size string // compress_above_size(size) + Value string // generic second argument (batch_size, relational_concurrency, cardinality) + Package string // package(default/package) + Alias string // import(alias, package/path) +} diff --git a/repository/shape/dql/decl/parser.go b/repository/shape/dql/decl/parser.go new file mode 100644 index 000000000..5fe6d13d9 --- /dev/null +++ b/repository/shape/dql/decl/parser.go @@ -0,0 +1,337 @@ +package decl + +import ( + "fmt" + "strings" + + "github.com/viant/parsly" +) + +// Parse extracts declarations from original DQL text. 
+func Parse(dql string) ([]*Declaration, error) { + cursor := parsly.NewCursor("", []byte(dql), 0) + var result []*Declaration + for cursor.Pos < cursor.InputSize { + matched := cursor.MatchAfterOptional(whitespaceMatcher, + commentBlockMatcher, + singleQuotedMatcher, + doubleQuotedMatcher, + identifierMatcher, + anyMatcher, + ) + switch matched.Code { + case identifierToken: + name := strings.ToLower(matched.Text(cursor)) + callOffset := matched.Offset + block := cursor.MatchAfterOptional(whitespaceMatcher, parenthesesBlockMatcher) + if block.Code != parenthesesBlockToken { + continue + } + rawCall := name + block.Text(cursor) + argsText := block.Text(cursor) + if len(argsText) < 2 { + continue + } + args := splitArgs(argsText[1 : len(argsText)-1]) + if rewrittenName, rewrittenArgs, ok := unwrapSetSpecial(name, args); ok { + name = rewrittenName + args = rewrittenArgs + rawCall = name + "(" + strings.Join(args, ", ") + ")" + } + decl := &Declaration{ + Kind: parseKind(name), + Raw: rawCall, + Offset: callOffset, + Args: args, + } + normalizeDeclaration(decl) + result = append(result, decl) + case parsly.Invalid: + return nil, cursor.NewError(identifierMatcher) + } + } + return result, nil +} + +func parseKind(name string) Kind { + switch strings.ToLower(name) { + case "cast": + return KindCast + case "tag": + return KindTag + case "set_limit": + return KindSetLimit + case "allow_nulls": + return KindAllowNulls + case "set_partitioner": + return KindSetPartitioner + case "use_connector": + return KindUseConnector + case "match_strategy": + return KindMatchStrategy + case "compress_above_size": + return KindCompressAboveSize + case "batch_size": + return KindBatchSize + case "relational_concurrency": + return KindRelationalConcurrency + case "publish_parent": + return KindPublishParent + case "cardinality": + return KindCardinality + case "package": + return KindPackage + case "import": + return KindImport + default: + return Kind(name) + } +} + +func 
normalizeDeclaration(decl *Declaration) { + if decl == nil || len(decl.Args) == 0 { + return + } + decl.Target = strings.TrimSpace(decl.Args[0]) + switch decl.Kind { + case KindCast: + if len(decl.Args) >= 2 { + decl.DataType = normalizeCastType(decl.Args[1]) + } else if len(decl.Args) == 1 { + target, dataType := splitCastExpression(decl.Args[0]) + if target != "" { + decl.Target = target + } + if dataType != "" { + decl.DataType = dataType + } + } + case KindTag: + if len(decl.Args) >= 2 { + decl.Tag = trimQuotes(strings.TrimSpace(decl.Args[1])) + } + case KindSetLimit: + if len(decl.Args) >= 2 { + decl.Limit = strings.TrimSpace(decl.Args[1]) + } + case KindUseConnector: + if len(decl.Args) >= 2 { + decl.Connector = trimQuotes(strings.TrimSpace(decl.Args[1])) + } + case KindMatchStrategy: + if len(decl.Args) >= 2 { + decl.Strategy = trimQuotes(strings.TrimSpace(decl.Args[1])) + } + case KindSetPartitioner: + if len(decl.Args) >= 2 { + decl.Partition = trimQuotes(strings.TrimSpace(decl.Args[1])) + } + if len(decl.Args) >= 3 { + decl.Value = strings.TrimSpace(decl.Args[2]) + } + case KindCompressAboveSize: + if len(decl.Args) >= 1 { + decl.Size = strings.TrimSpace(decl.Args[0]) + } + case KindBatchSize, KindRelationalConcurrency, KindCardinality: + if len(decl.Args) >= 2 { + decl.Value = trimQuotes(strings.TrimSpace(decl.Args[1])) + } + case KindPackage: + decl.Package = trimQuotes(strings.TrimSpace(decl.Args[0])) + case KindImport: + switch len(decl.Args) { + case 1: + decl.Package = trimQuotes(strings.TrimSpace(decl.Args[0])) + default: + decl.Alias = trimQuotes(strings.TrimSpace(decl.Args[0])) + decl.Package = trimQuotes(strings.TrimSpace(decl.Args[1])) + } + } +} + +func splitCastExpression(expr string) (string, string) { + text := strings.TrimSpace(expr) + if text == "" { + return "", "" + } + lowered := strings.ToLower(text) + quote := rune(0) + escape := false + depth := 0 + for i := 0; i < len(text); i++ { + r := rune(text[i]) + if quote != 0 { + if escape 
{ + escape = false + continue + } + if r == '\\' { + escape = true + continue + } + if r == quote { + quote = 0 + } + continue + } + switch r { + case '\'', '"', '`': + quote = r + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + } + if depth == 0 && i+4 <= len(text) { + chunk := lowered[i : i+4] + if chunk == " as " { + left := strings.TrimSpace(text[:i]) + right := strings.TrimSpace(text[i+4:]) + return left, trimQuotes(right) + } + } + } + return text, "" +} + +func normalizeCastType(arg string) string { + text := strings.TrimSpace(arg) + lower := strings.ToLower(text) + if strings.HasPrefix(lower, "as ") { + text = strings.TrimSpace(text[3:]) + } + return trimQuotes(text) +} + +func unwrapSetSpecial(name string, args []string) (string, []string, bool) { + if strings.ToLower(strings.TrimSpace(name)) != "set" || len(args) != 1 { + return "", nil, false + } + expr := args[0] + if expr == "" { + return "", nil, false + } + for _, functionName := range []string{"package", "import"} { + token := "$" + functionName + "(" + lowerExpr := strings.ToLower(expr) + idx := strings.Index(lowerExpr, token) + if idx == -1 { + continue + } + openPos := idx + len(token) - 1 + closePos := findClosingParen(expr, openPos) + if closePos <= openPos { + continue + } + inner := strings.TrimSpace(expr[openPos+1 : closePos]) + return functionName, splitArgs(inner), true + } + return "", nil, false +} + +func findClosingParen(text string, openPos int) int { + if openPos < 0 || openPos >= len(text) || text[openPos] != '(' { + return -1 + } + depth := 0 + var quote rune + escape := false + runes := []rune(text) + for i := 0; i < len(runes); i++ { + r := runes[i] + if quote != 0 { + if escape { + escape = false + continue + } + if r == '\\' { + escape = true + continue + } + if r == quote { + quote = 0 + } + continue + } + switch r { + case '\'', '"', '`': + quote = r + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return i + } + } + } + return -1 +} + +func 
splitArgs(text string) []string { + var result []string + start := 0 + depth := 0 + var quote rune + escape := false + runes := []rune(text) + for i, r := range runes { + if quote != 0 { + if escape { + escape = false + continue + } + if r == '\\' { + escape = true + continue + } + if r == quote { + quote = 0 + } + continue + } + switch r { + case '\'', '"', '`': + quote = r + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + case ',': + if depth == 0 { + result = append(result, strings.TrimSpace(string(runes[start:i]))) + start = i + 1 + } + } + } + last := strings.TrimSpace(string(runes[start:])) + if last != "" || strings.TrimSpace(text) != "" { + result = append(result, last) + } + return result +} + +func trimQuotes(value string) string { + value = strings.TrimSpace(value) + if len(value) < 2 { + return value + } + first := value[0] + last := value[len(value)-1] + if (first == '\'' && last == '\'') || (first == '"' && last == '"') || (first == '`' && last == '`') { + return value[1 : len(value)-1] + } + return value +} + +func (d *Declaration) String() string { + if d == nil { + return "" + } + return fmt.Sprintf("%s(%s)", d.Kind, strings.Join(d.Args, ", ")) +} diff --git a/repository/shape/dql/decl/parser_test.go b/repository/shape/dql/decl/parser_test.go new file mode 100644 index 000000000..f5939822d --- /dev/null +++ b/repository/shape/dql/decl/parser_test.go @@ -0,0 +1,142 @@ +package decl + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParse_ExtractsDeclarations(t *testing.T) { + sql := ` +SELECT ad.*, + cast(ad.ACTIVE as bool), + cast(ad.CHANNELS AS '[]string'), + tag(ad.CHANNELS, 'sqlx:"-"'), + set_limit(ad, 25), + allow_nulls(ad) +FROM CI_AD_ORDER ad` + decls, err := Parse(sql) + require.NoError(t, err) + require.Len(t, decls, 5) + + require.Equal(t, KindCast, decls[0].Kind) + require.Equal(t, "ad.ACTIVE", decls[0].Target) + require.Equal(t, "bool", decls[0].DataType) + + require.Equal(t, KindCast, 
decls[1].Kind) + require.Equal(t, "[]string", decls[1].DataType) + + require.Equal(t, KindTag, decls[2].Kind) + require.Equal(t, `sqlx:"-"`, decls[2].Tag) + + require.Equal(t, KindSetLimit, decls[3].Kind) + require.Equal(t, "25", decls[3].Limit) + + require.Equal(t, KindAllowNulls, decls[4].Kind) + require.Equal(t, "ad", decls[4].Target) +} + +func TestParse_IgnoresQuotedAndCommented(t *testing.T) { + sql := ` +SELECT + 'cast(a as int)', + "tag(x,'json')", + /* set_limit(a,1), allow_nulls(a) */ + cast(t.ACTIVE as bool) +FROM T t` + decls, err := Parse(sql) + require.NoError(t, err) + require.Len(t, decls, 1) + require.Equal(t, KindCast, decls[0].Kind) +} + +func TestParse_SupportsNestedArgs(t *testing.T) { + sql := `SELECT tag(ad.NAME, concat('a,', upper('b'))), set_limit(ad, ifnull(25, 10)) FROM T ad` + decls, err := Parse(sql) + require.NoError(t, err) + require.Len(t, decls, 2) + require.Equal(t, KindTag, decls[0].Kind) + require.Equal(t, "concat('a,', upper('b'))", decls[0].Tag) + require.Equal(t, "ifnull(25, 10)", decls[1].Limit) +} + +func TestParse_AllowsWhitespaceBetweenNameAndParen(t *testing.T) { + sql := `SELECT cast (ad.ACTIVE as bool), allow_nulls ( ad ) FROM T ad` + decls, err := Parse(sql) + require.NoError(t, err) + require.Len(t, decls, 2) + require.Equal(t, KindCast, decls[0].Kind) + require.Equal(t, KindAllowNulls, decls[1].Kind) +} + +func TestParse_InvalidInputProducesNoError(t *testing.T) { + sql := `SELECT cast ad.ACTIVE as bool FROM T` + decls, err := Parse(sql) + require.NoError(t, err) + require.Len(t, decls, 0) +} + +func TestParse_ExtractsExtendedSettings(t *testing.T) { + sql := ` +SELECT x.*, + set_partitioner(x, 'pkg.Part', 7), + use_connector(x, 'bq_mdp'), + match_strategy(x, 'read_all'), + compress_above_size(1024), + batch_size(x, 20000), + relational_concurrency(x, 10), + publish_parent(x), + cardinality(x, 'One') +FROM T x` + decls, err := Parse(sql) + require.NoError(t, err) + require.Len(t, decls, 8) + + require.Equal(t, 
KindSetPartitioner, decls[0].Kind) + require.Equal(t, "pkg.Part", decls[0].Partition) + require.Equal(t, "7", decls[0].Value) + + require.Equal(t, KindUseConnector, decls[1].Kind) + require.Equal(t, "bq_mdp", decls[1].Connector) + + require.Equal(t, KindMatchStrategy, decls[2].Kind) + require.Equal(t, "read_all", decls[2].Strategy) + + require.Equal(t, KindCompressAboveSize, decls[3].Kind) + require.Equal(t, "1024", decls[3].Size) + + require.Equal(t, KindBatchSize, decls[4].Kind) + require.Equal(t, "20000", decls[4].Value) + + require.Equal(t, KindRelationalConcurrency, decls[5].Kind) + require.Equal(t, "10", decls[5].Value) + + require.Equal(t, KindPublishParent, decls[6].Kind) + require.Equal(t, "x", decls[6].Target) + + require.Equal(t, KindCardinality, decls[7].Kind) + require.Equal(t, "One", decls[7].Value) +} + +func TestParse_ExtractsPackageAndImport(t *testing.T) { + sql := ` +#set($_ = $package('mdp/performance')) +#set($_ = $import('perf', 'github.com/acme/mdp/performance')) +#set($_ = $import('github.com/acme/shared/types')) +SELECT x.* +FROM T x` + decls, err := Parse(sql) + require.NoError(t, err) + require.Len(t, decls, 3) + + require.Equal(t, KindPackage, decls[0].Kind) + require.Equal(t, "mdp/performance", decls[0].Package) + + require.Equal(t, KindImport, decls[1].Kind) + require.Equal(t, "perf", decls[1].Alias) + require.Equal(t, "github.com/acme/mdp/performance", decls[1].Package) + + require.Equal(t, KindImport, decls[2].Kind) + require.Equal(t, "", decls[2].Alias) + require.Equal(t, "github.com/acme/shared/types", decls[2].Package) +} diff --git a/repository/shape/dql/diag/codes.go b/repository/shape/dql/diag/codes.go new file mode 100644 index 000000000..7fa6a96e9 --- /dev/null +++ b/repository/shape/dql/diag/codes.go @@ -0,0 +1,48 @@ +package diag + +const ( + CodeParseEmpty = "DQL-PARSE-EMPTY" + CodeParseSyntax = "DQL-PARSE-SYNTAX" + CodeParseUnknownNonRead = "DQL-PARSE-UNKNOWN-NONREAD" + + CodeDirPackage = "DQL-DIR-PACKAGE" + CodeDirImport 
= "DQL-DIR-IMPORT" + CodeDirMeta = "DQL-DIR-META" + CodeDirCache = "DQL-DIR-CACHE" + CodeDirMCP = "DQL-DIR-MCP" + CodeDirConnector = "DQL-DIR-CONNECTOR" + CodeDirRoute = "DQL-DIR-ROUTE" + CodeDirMarshal = "DQL-DIR-MARSHAL" + CodeDirUnmarshal = "DQL-DIR-UNMARSHAL" + CodeDirFormat = "DQL-DIR-FORMAT" + CodeDirDateFormat = "DQL-DIR-DATE-FORMAT" + CodeDirCaseFormat = "DQL-DIR-CASE-FORMAT" + CodeDirUnsupported = "DQL-DIR-UNSUPPORTED" + + CodeOptParse = "DQL-OPT-PARSE" + CodeSQLIRawSelector = "DQL-SQLI-RAW-SELECTOR" + CodeViewMissingSQL = "DQL-VIEW-MISSING-SQL" + CodeViewCardinality = "DQL-VIEW-CARDINALITY" + CodeDeclOptionArgs = "DQL-DECL-OPTION-ARGS" + CodeDeclQuerySelector = "DQL-DECL-QUERY-SELECTOR" + CodeRelMissingON = "DQL-REL-MISSING-ON" + CodeRelUnsupported = "DQL-REL-UNSUPPORTED-PREDICATE" + CodeRelAmbiguous = "DQL-REL-AMBIGUOUS-LINK" + CodeRelNoLinks = "DQL-REL-NO-LINKS" + CodeCompRefInvalid = "DQL-COMP-REF-INVALID" + CodeCompRouteMissing = "DQL-COMP-ROUTE-MISSING" + CodeCompRouteInvalid = "DQL-COMP-ROUTE-INVALID" + CodeCompCycle = "DQL-COMP-CYCLE" + CodeCompTypeCollision = "DQL-COMP-TYPE-COLLISION" + CodeTypeCtxInvalid = "DQL-TYPECTX-INVALID" + CodeDMLMixed = "DQL-DML-MIXED" + CodeDMLServiceArg = "DQL-DML-SERVICE-ARG" + CodeDMLInsert = "DQL-DML-INSERT" + CodeDMLUpdate = "DQL-DML-UPDATE" + CodeDMLDelete = "DQL-DML-DELETE" + CodeColDiscoveryReq = "DQL-COL-DISCOVERY-REQUIRED" + + PrefixRel = "DQL-REL-" + PrefixComp = "DQL-COMP-" + PrefixSQLI = "DQL-SQLI-" +) diff --git a/repository/shape/dql/holder/model.go b/repository/shape/dql/holder/model.go new file mode 100644 index 000000000..eac005500 --- /dev/null +++ b/repository/shape/dql/holder/model.go @@ -0,0 +1,128 @@ +package holder + +// ComponentHolder is a meta/tag-driven canonical holder for DQL/YAML parity +// and conversion to Datly internal/YAML representation. 
+type ComponentHolder struct { + Route RouteShape `shape:"route"` + Component ComponentShape `shape:"component"` + Input IOShape `shape:"input"` + Output IOShape `shape:"output"` + ViewGraph ViewGraphShape `shape:"views"` + Dependencies DependencyShape `shape:"deps"` + Meta map[string]string `shape:"meta"` +} + +type RouteShape struct { + Name string `shape:"route.name"` + URI string `shape:"route.uri"` + Method string `shape:"route.method"` + Service string `shape:"route.service"` + Description string `shape:"route.description"` + MCPTool bool `shape:"route.mcpTool"` + ViewRef string `shape:"route.viewRef"` +} + +type ComponentShape struct { + Name string `shape:"component.name"` + Package string `shape:"component.package"` + SourceURL string `shape:"component.sourceURL"` + Handler string `shape:"component.handler"` + Settings map[string]string `shape:"component.settings"` + Dependencies []string `shape:"component.dependencies"` +} + +type IOShape struct { + TypeName string `shape:"io.typeName"` + Package string `shape:"io.package"` + Cardinality string `shape:"io.cardinality"` + CaseFormat string `shape:"io.caseFormat"` + Exclude []string `shape:"io.exclude"` + Parameters []ParameterShape `shape:"io.parameters"` +} + +type ParameterShape struct { + Name string `shape:"param.name"` + Kind string `shape:"param.kind"` + In string `shape:"param.in"` + Required *bool `shape:"param.required"` + DataType string `shape:"param.dataType"` + Package string `shape:"param.package"` + Cardinality string `shape:"param.cardinality"` + Tag string `shape:"param.tag"` + TagMeta map[string]string `shape:"param.tagMeta"` + CodecName string `shape:"param.codec.name"` + CodecArgs []string `shape:"param.codec.args"` + ErrorStatusCode int `shape:"param.errorStatusCode"` + Cacheable *bool `shape:"param.cacheable"` + Scope string `shape:"param.scope"` + Connector string `shape:"param.connector"` + Limit *int `shape:"param.limit"` + Value string `shape:"param.value"` + Predicates 
[]PredicateShape `shape:"param.predicates"` + LocationInput *LocationShape `shape:"param.locationInput"` +} + +type PredicateShape struct { + Name string `shape:"predicate.name"` + Group int `shape:"predicate.group"` + Ensure bool `shape:"predicate.ensure"` + Args []string `shape:"predicate.args"` +} + +type LocationShape struct { + Name string `shape:"location.name"` + Package string `shape:"location.package"` + Parameters []ParameterShape `shape:"location.parameters"` +} + +type ViewGraphShape struct { + Root string `shape:"views.root"` + Views []ViewShape `shape:"views.items"` +} + +type ViewShape struct { + Name string `shape:"view.name"` + Mode string `shape:"view.mode"` + Table string `shape:"view.table"` + Module string `shape:"view.module"` + AllowNulls *bool `shape:"view.allowNulls"` + Connector string `shape:"view.connector"` + Partitioner string `shape:"view.partitioner"` + PartitionedConcurrency int `shape:"view.partitionedConcurrency"` + RelationalConcurrency int `shape:"view.relationalConcurrency"` + SourceURL string `shape:"view.sourceURL"` + Selector SelectorShape `shape:"view.selector"` + With []RelationShape `shape:"view.with"` + Columns map[string]string `shape:"view.columns"` +} + +type SelectorShape struct { + Namespace string `shape:"selector.namespace"` + Limit *int `shape:"selector.limit"` + Criteria *bool `shape:"selector.criteria"` + Projection *bool `shape:"selector.projection"` + OrderBy *bool `shape:"selector.orderBy"` + Offset *bool `shape:"selector.offset"` +} + +type RelationShape struct { + Name string `shape:"relation.name"` + Holder string `shape:"relation.holder"` + Cardinality string `shape:"relation.cardinality"` + IncludeColumn *bool `shape:"relation.includeColumn"` + Ref string `shape:"relation.ref"` + On []JoinShape `shape:"relation.on"` +} + +type JoinShape struct { + Namespace string `shape:"join.namespace"` + Column string `shape:"join.column"` + Field string `shape:"join.field"` +} + +type DependencyShape struct { + With 
[]string `shape:"deps.with"` + Connectors []string `shape:"deps.connectors"` + Constants []string `shape:"deps.constants"` + Substitutions []string `shape:"deps.substitutions"` +} diff --git a/repository/shape/dql/holder/model_test.go b/repository/shape/dql/holder/model_test.go new file mode 100644 index 000000000..da73b4052 --- /dev/null +++ b/repository/shape/dql/holder/model_test.go @@ -0,0 +1,52 @@ +package holder + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestComponentHolder_CoversRequiredSemantics(t *testing.T) { + required := []string{ + "route.uri", "route.method", "route.service", "route.viewRef", + "component.sourceURL", "component.settings", "component.dependencies", + "io.typeName", "io.parameters", "io.exclude", "io.caseFormat", + "param.name", "param.kind", "param.in", "param.required", + "param.dataType", "param.cardinality", "param.tag", "param.tagMeta", + "param.codec.name", "param.codec.args", "param.predicates", + "param.errorStatusCode", "param.cacheable", "param.scope", "param.connector", "param.limit", "param.value", + "view.name", "view.mode", "view.table", "view.connector", "view.partitioner", "view.partitionedConcurrency", "view.relationalConcurrency", "view.sourceURL", "view.selector", "view.with", + "selector.namespace", "selector.limit", "selector.criteria", "selector.projection", "selector.orderBy", "selector.offset", + "relation.name", "relation.holder", "relation.cardinality", "relation.ref", "relation.on", + "join.namespace", "join.column", "join.field", + "deps.with", "deps.connectors", "deps.constants", "deps.substitutions", + } + + got := collectShapeTags(reflect.TypeOf(ComponentHolder{}), map[string]struct{}{}, map[reflect.Type]bool{}) + for _, item := range required { + _, ok := got[item] + require.Truef(t, ok, "missing semantic tag %q in holder model", item) + } +} + +func collectShapeTags(t reflect.Type, acc map[string]struct{}, visited map[reflect.Type]bool) map[string]struct{} { 
+ for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice || t.Kind() == reflect.Array { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return acc + } + if visited[t] { + return acc + } + visited[t] = true + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if tag := field.Tag.Get("shape"); tag != "" { + acc[tag] = struct{}{} + } + collectShapeTags(field.Type, acc, visited) + } + return acc +} diff --git a/repository/shape/dql/ir/model.go b/repository/shape/dql/ir/model.go new file mode 100644 index 000000000..c35b16929 --- /dev/null +++ b/repository/shape/dql/ir/model.go @@ -0,0 +1,24 @@ +package ir + +import ( + "fmt" + + "gopkg.in/yaml.v3" +) + +// Document represents DQL internal representation independent of YAML rendering. +// Root carries the route/resource model as generic tree. +type Document struct { + Root map[string]any +} + +func FromYAML(data []byte) (*Document, error) { + if len(data) == 0 { + return nil, fmt.Errorf("dql ir: empty source") + } + var root map[string]any + if err := yaml.Unmarshal(data, &root); err != nil { + return nil, err + } + return &Document{Root: root}, nil +} diff --git a/repository/shape/dql/load/loader.go b/repository/shape/dql/load/loader.go new file mode 100644 index 000000000..341a3e654 --- /dev/null +++ b/repository/shape/dql/load/loader.go @@ -0,0 +1,125 @@ +package load + +import ( + "context" + "fmt" + "strings" + + "github.com/viant/datly/repository/shape" + dqlplan "github.com/viant/datly/repository/shape/dql/plan" + shapeplan "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/scan" +) + +// Artifact carries canonical representation for parity checks. +type Artifact struct { + Canonical map[string]any +} + +func FromPlan(result *dqlplan.Result) *Artifact { + if result == nil { + return nil + } + return &Artifact{Canonical: result.Canonical} +} + +// FromHolderStruct builds a canonical shape artifact directly from a tagged holder struct. 
+func FromHolderStruct(ctx context.Context, holder any) (*Artifact, error) { + if holder == nil { + return nil, fmt.Errorf("dql load: holder was nil") + } + scanned, err := scan.New().Scan(ctx, &shape.Source{Struct: holder}) + if err != nil { + return nil, err + } + planned, err := shapeplan.New().Plan(ctx, scanned) + if err != nil { + return nil, err + } + shapeResult, ok := shapeplan.ResultFrom(planned) + if !ok { + return nil, fmt.Errorf("dql load: unsupported shape plan kind %q", planned.Plan.ShapeSpecKind()) + } + views := make([]any, 0, len(shapeResult.Views)) + for _, item := range shapeResult.Views { + if item == nil { + continue + } + entry := map[string]any{ + "Name": item.Name, + "Table": item.Table, + "ConnectorRef": item.Connector, + "Holder": item.Holder, + "Cardinality": item.Cardinality, + } + if item.Partitioner != "" { + entry["Partitioner"] = item.Partitioner + } + if item.PartitionedConcurrency > 0 { + entry["PartitionedConcurrency"] = item.PartitionedConcurrency + } + if item.RelationalConcurrency > 0 { + entry["RelationalConcurrency"] = item.RelationalConcurrency + } + if item.Ref != "" { + entry["Ref"] = item.Ref + } + if item.SQLURI != "" { + entry["SourceURL"] = item.SQLURI + } + if item.SQL != "" { + entry["SQL"] = item.SQL + } + if links := relationLinks(item); len(links) > 0 { + entry["Links"] = links + } + views = append(views, entry) + } + return &Artifact{ + Canonical: map[string]any{ + "Resource": map[string]any{ + "Views": views, + }, + }, + }, nil +} + +func relationLinks(item *shapeplan.View) []string { + if item == nil || len(item.Relations) == 0 { + return nil + } + var result []string + for _, relation := range item.Relations { + if relation == nil || len(relation.On) == 0 { + continue + } + for _, on := range relation.On { + if on == nil { + continue + } + expr := strings.TrimSpace(on.Expression) + if expr == "" { + left := selector(on.ParentNamespace, on.ParentColumn) + right := selector(on.RefNamespace, on.RefColumn) + if 
left == "" || right == "" { + continue + } + expr = left + "=" + right + } + result = append(result, expr) + } + } + return result +} + +func selector(namespace, column string) string { + column = strings.TrimSpace(column) + if column == "" { + return "" + } + namespace = strings.TrimSpace(namespace) + if namespace == "" { + return column + } + return namespace + "." + column +} diff --git a/repository/shape/dql/load/loader_test.go b/repository/shape/dql/load/loader_test.go new file mode 100644 index 000000000..6d6e36629 --- /dev/null +++ b/repository/shape/dql/load/loader_test.go @@ -0,0 +1,58 @@ +package load + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type sampleView struct { + ID int +} + +type manyHolder struct { + Rows *[]sampleView `view:"rows,table=CI_SAMPLE,connector=ci_ads,partitioner=custom.Partitioner,concurrency=4,relationalConcurrency=2" sql:"SELECT ID FROM CI_SAMPLE"` +} + +type oneHolder struct { + Row *sampleView `view:"row,table=CI_SAMPLE,connector=ci_ads"` +} + +func TestFromHolderStruct_ManyCardinality(t *testing.T) { + artifact, err := FromHolderStruct(context.Background(), &manyHolder{}) + require.NoError(t, err) + require.NotNil(t, artifact) + + resource, ok := artifact.Canonical["Resource"].(map[string]any) + require.True(t, ok) + views, ok := resource["Views"].([]any) + require.True(t, ok) + require.Len(t, views, 1) + view, ok := views[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "rows", view["Name"]) + require.Equal(t, "CI_SAMPLE", view["Table"]) + require.Equal(t, "ci_ads", view["ConnectorRef"]) + require.Equal(t, "Rows", view["Holder"]) + require.Equal(t, "many", view["Cardinality"]) + require.Equal(t, "custom.Partitioner", view["Partitioner"]) + require.EqualValues(t, 4, view["PartitionedConcurrency"]) + require.EqualValues(t, 2, view["RelationalConcurrency"]) +} + +func TestFromHolderStruct_OneCardinality(t *testing.T) { + artifact, err := FromHolderStruct(context.Background(), 
&oneHolder{}) + require.NoError(t, err) + require.NotNil(t, artifact) + + resource, ok := artifact.Canonical["Resource"].(map[string]any) + require.True(t, ok) + views, ok := resource["Views"].([]any) + require.True(t, ok) + require.Len(t, views, 1) + view, ok := views[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "row", view["Name"]) + require.Equal(t, "one", view["Cardinality"]) +} diff --git a/repository/shape/dql/optimize/optimizer.go b/repository/shape/dql/optimize/optimizer.go new file mode 100644 index 000000000..030f26fb4 --- /dev/null +++ b/repository/shape/dql/optimize/optimizer.go @@ -0,0 +1,99 @@ +package optimize + +import ( + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/velty" + "github.com/viant/velty/ast" + aexpr "github.com/viant/velty/ast/expr" +) + +// Rewrite applies lightweight template simplification and emits diagnostics. +// It is intentionally conservative: only dead #if(false) blocks without else are blanked. 
+func Rewrite(dql string) (string, []*dqlshape.Diagnostic) { + if strings.TrimSpace(dql) == "" { + return dql, nil + } + adjuster := &hookAdjuster{source: []byte(dql), seenOffset: map[int]struct{}{}} + out, err := velty.TransformTemplate([]byte(dql), adjuster) + if err != nil { + adjuster.diagnostics = append(adjuster.diagnostics, &dqlshape.Diagnostic{ + Code: dqldiag.CodeOptParse, + Severity: dqlshape.SeverityWarning, + Message: "velty optimization pass skipped due to parse issue", + Hint: "check template syntax near directives and expressions", + Span: dqlshape.Span{Start: dqlshape.Position{Line: 1, Char: 1}, End: dqlshape.Position{Line: 1, Char: 1}}, + }) + return dql, adjuster.diagnostics + } + return string(out), adjuster.diagnostics +} + +type hookAdjuster struct { + source []byte + seenOffset map[int]struct{} + diagnostics []*dqlshape.Diagnostic +} + +func (a *hookAdjuster) Adjust(node ast.Node, ctx *velty.ParserContext) (velty.Action, error) { + switch actual := node.(type) { + case *aexpr.Select: + a.captureSQLInjectionRisk(actual, ctx) + } + return velty.Keep(), nil +} + +func (a *hookAdjuster) captureSQLInjectionRisk(sel *aexpr.Select, ctx *velty.ParserContext) { + if sel == nil || ctx == nil { + return + } + if ctx.CurrentExprContext().Kind == velty.CtxSetLHS { + return + } + span, ok := ctx.GetSpan(sel) + if !ok { + return + } + if strings.EqualFold(sel.ID, "Nop") { + return + } + if a.inNoopCall(span.Start) { + return + } + if _, exists := a.seenOffset[span.Start]; exists { + return + } + a.seenOffset[span.Start] = struct{}{} + pos := ctx.ResolvePosition(span) + a.diagnostics = append(a.diagnostics, &dqlshape.Diagnostic{ + Code: dqldiag.CodeSQLIRawSelector, + Severity: dqlshape.SeverityWarning, + Message: "raw selector interpolation detected in SQL template", + Hint: "prefer bind parameters or validated allow-listed fragments", + Span: dqlshape.Span{ + Start: dqlshape.Position{Offset: span.Start, Line: pos.Line, Char: pos.Col}, + End: 
dqlshape.Position{Offset: span.End, Line: pos.EndLine, Char: pos.EndCol}, + }, + }) +} + +func (a *hookAdjuster) inNoopCall(pos int) bool { + if pos <= 0 || pos > len(a.source) { + return false + } + prefix := string(a.source[:pos]) + nopPos := strings.LastIndex(prefix, "$Nop(") + if nopPos == -1 { + nopPos = strings.LastIndex(prefix, "$nop(") + } + if nopPos == -1 { + return false + } + if nl := strings.LastIndex(prefix, "\n"); nl > nopPos { + return false + } + segment := prefix[nopPos:pos] + return strings.Count(segment, "(") > strings.Count(segment, ")") +} diff --git a/repository/shape/dql/optimize/optimizer_test.go b/repository/shape/dql/optimize/optimizer_test.go new file mode 100644 index 000000000..ed1d3ac79 --- /dev/null +++ b/repository/shape/dql/optimize/optimizer_test.go @@ -0,0 +1,40 @@ +package optimize + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" +) + +func TestRewrite_SelectorInterpolationReportsPosition(t *testing.T) { + input := "SELECT id FROM orders WHERE id = $Unsafe.Id" + _, diagnostics := Rewrite(input) + require.NotEmpty(t, diagnostics) + diag := diagnostics[0] + assert.Equal(t, dqldiag.CodeSQLIRawSelector, diag.Code) + assert.Equal(t, 1, diag.Span.Start.Line) + assert.Greater(t, diag.Span.Start.Char, 1) +} + +func TestRewrite_ParseFailureFallsBack(t *testing.T) { + input := "#if(true)" + out, diagnostics := Rewrite(input) + assert.Equal(t, input, out) + require.NotEmpty(t, diagnostics) + assert.Equal(t, dqldiag.CodeOptParse, diagnostics[len(diagnostics)-1].Code) + assert.True(t, strings.Contains(strings.ToLower(diagnostics[len(diagnostics)-1].Message), "optimization pass")) +} + +func TestRewrite_NopDoesNotReportSelectorInterpolation(t *testing.T) { + input := "SELECT 1 WHERE 1=1 $Nop($Unsafe.Id)" + _, diagnostics := Rewrite(input) + for _, item := range diagnostics { + if item == nil { + continue + } + 
assert.NotEqual(t, dqldiag.CodeSQLIRawSelector, item.Code) + } +} diff --git a/repository/shape/dql/parity/adorder_parity_test.go b/repository/shape/dql/parity/adorder_parity_test.go new file mode 100644 index 000000000..667a6c79d --- /dev/null +++ b/repository/shape/dql/parity/adorder_parity_test.go @@ -0,0 +1,92 @@ +package parity + +import ( + "context" + "os" + "strings" + "testing" + + dqlplan "github.com/viant/datly/repository/shape/dql/plan" + dqlyaml "github.com/viant/datly/repository/shape/dql/render/yaml" + dqlscan "github.com/viant/datly/repository/shape/dql/scan" +) + +func TestAdorderDQL_CanonicalParityWithYAML(t *testing.T) { + if os.Getenv("DATLY_RUN_ADORDER_PARITY") != "1" { + t.Skip("set DATLY_RUN_ADORDER_PARITY=1 to run adorder parity suite") + } + dqlPath := "/Users/adrianwitas/Downloads/pp/dql/platform/adorder/adorder.dql" + yamlPath := "/Users/adrianwitas/Downloads/pp/repo/dev/Datly/routes/platform/adorder/adorder.yaml" + repoPath := "/Users/adrianwitas/Downloads/pp/repo/dev" + + if _, err := os.Stat(dqlPath); err != nil { + t.Skipf("missing fixture dql file: %v", err) + } + if _, err := os.Stat(yamlPath); err != nil { + t.Skipf("missing fixture yaml file: %v", err) + } + + scanner := dqlscan.New() + connectors := resolveConnectors([]string{ + "ci_ads|mysql|root:dev@tcp(127.0.0.1:3307)/ci_ads?parseTime=true&charset=utf8mb4&collation=utf8mb4_bin", + "ci_logs|mysql|root:dev@tcp(127.0.0.1:3307)/ci_logs?parseTime=true", + }) + scanned, err := scanner.Scan(context.Background(), &dqlscan.Request{ + DQLURL: dqlPath, + Repository: repoPath, + ModulePrefix: "platform/adorder", + APIPrefix: "/v1/api", + Connectors: connectors, + }) + if err != nil { + if strings.Contains(err.Error(), "Unknown database") || strings.Contains(err.Error(), "failed to discover/detect column") { + t.Skipf("environment not ready for parity scan: %v", err) + } + t.Fatalf("scan failed: %v", err) + } + + fromDQL, err := dqlplan.BuildFromIR(scanned.IR) + if err != nil { + 
t.Fatalf("plan from dql failed: %v", err) + } + + yamlData, err := os.ReadFile(yamlPath) + if err != nil { + t.Fatalf("read yaml failed: %v", err) + } + fromYAML, err := dqlplan.Build(yamlData) + if err != nil { + t.Fatalf("plan from yaml failed: %v", err) + } + issues := Diff(fromDQL.Canonical, fromYAML.Canonical) + if len(issues) > 0 { + max := len(issues) + if max > 30 { + max = 30 + } + for i := 0; i < max; i++ { + t.Log(issues[i]) + } + t.Fatalf("canonical diff detected: %d issues", len(issues)) + } + + renderedYAML, err := dqlyaml.Encode(scanned.IR) + if err != nil { + t.Fatalf("render yaml from IR failed: %v", err) + } + fromRendered, err := dqlplan.Build(renderedYAML) + if err != nil { + t.Fatalf("plan from rendered yaml failed: %v", err) + } + roundTripIssues := Diff(fromRendered.Canonical, fromYAML.Canonical) + if len(roundTripIssues) > 0 { + max := len(roundTripIssues) + if max > 30 { + max = 30 + } + for i := 0; i < max; i++ { + t.Log(roundTripIssues[i]) + } + t.Fatalf("ir->yaml canonical diff detected: %d issues", len(roundTripIssues)) + } +} diff --git a/repository/shape/dql/parity/connectors.go b/repository/shape/dql/parity/connectors.go new file mode 100644 index 000000000..eaaacab8b --- /dev/null +++ b/repository/shape/dql/parity/connectors.go @@ -0,0 +1,28 @@ +package parity + +import ( + "fmt" + "os" + "strings" +) + +// resolveConnectors returns connectors from env override, or defaults. +// When DATLY_PARITY_SQLITE_DSN is set, all default connector names are mapped to sqlite3. 
func resolveConnectors(defaults []string) []string {
	// Explicit override wins: DATLY_PARITY_CONNECTORS is a CSV of connector specs.
	if override := splitNonEmpty(os.Getenv("DATLY_PARITY_CONNECTORS")); len(override) > 0 {
		return override
	}
	sqliteDSN := strings.TrimSpace(os.Getenv("DATLY_PARITY_SQLITE_DSN"))
	if sqliteDSN == "" {
		return defaults
	}
	// Map every default connector name onto the single sqlite3 DSN,
	// preserving the "name|driver|dsn" triple format.
	ret := make([]string, 0, len(defaults))
	for _, item := range defaults {
		parts := strings.Split(item, "|")
		if len(parts) == 0 || strings.TrimSpace(parts[0]) == "" {
			continue
		}
		ret = append(ret, fmt.Sprintf("%s|sqlite3|%s", strings.TrimSpace(parts[0]), sqliteDSN))
	}
	return ret
}
diff --git a/repository/shape/dql/parity/diff.go b/repository/shape/dql/parity/diff.go
new file mode 100644
index 000000000..32b77ff53
--- /dev/null
+++ b/repository/shape/dql/parity/diff.go
package parity

import (
	"fmt"
	"reflect"
	"sort"
)

// Diff compares two canonical maps and returns human-readable mismatches.
func Diff(a, b map[string]any) []string {
	var issues []string
	diffValue("$", a, b, &issues)
	// Sorted so output is deterministic despite random map iteration order.
	sort.Strings(issues)
	return issues
}

// diffValue recursively compares a and b, appending one message per mismatch
// to issues; path is the JSON-path-like location of the node being compared.
func diffValue(path string, a, b any, issues *[]string) {
	if a == nil && b == nil {
		return
	}
	if a == nil || b == nil {
		*issues = append(*issues, fmt.Sprintf("%s: one side is nil", path))
		return
	}
	// Dynamic types must match before any typed comparison below.
	if reflect.TypeOf(a) != reflect.TypeOf(b) {
		*issues = append(*issues, fmt.Sprintf("%s: type mismatch %T != %T", path, a, b))
		return
	}
	switch av := a.(type) {
	case map[string]any:
		// Safe assertion: types were proven identical above.
		bv := b.(map[string]any)
		for k, v := range av {
			bvItem, ok := bv[k]
			if !ok {
				*issues = append(*issues, fmt.Sprintf("%s.%s: missing in rhs", path, k))
				continue
			}
			diffValue(path+"."+k, v, bvItem, issues)
		}
		for k := range bv {
			if _, ok := av[k]; !ok {
				*issues = append(*issues, fmt.Sprintf("%s.%s: missing in lhs", path, k))
			}
		}
	case []any:
		bv := b.([]any)
		if len(av) != len(bv) {
			*issues = append(*issues, fmt.Sprintf("%s: len mismatch %d != %d", path, len(av), len(bv)))
			return
		}
		for i := range av {
			diffValue(fmt.Sprintf("%s[%d]", path, i), av[i], bv[i], issues)
		}
	default:
		// Scalars and any other types: exact equality only.
		if !reflect.DeepEqual(a, b) {
			*issues = append(*issues, fmt.Sprintf("%s: value mismatch %v != %v", path, a, b))
		}
	}
}
diff --git a/repository/shape/dql/parity/mdp_parity_test.go b/repository/shape/dql/parity/mdp_parity_test.go
new file mode 100644
index 000000000..6941ee914
--- /dev/null
+++ b/repository/shape/dql/parity/mdp_parity_test.go
package parity

import (
	"context"
	"os"
	"path/filepath"
	"strings"
	"testing"

	dqlplan "github.com/viant/datly/repository/shape/dql/plan"
	dqlscan "github.com/viant/datly/repository/shape/dql/scan"
)

func TestMDPDQL_CanonicalParityWithRoutes(t *testing.T) {
	// Opt-in suite: requires local fixtures and a running database.
	if os.Getenv("DATLY_RUN_MDP_PARITY") != "1" {
		t.Skip("set DATLY_RUN_MDP_PARITY=1 to run mdp parity suite")
	}
	mdpRoot := envOr("DATLY_MDP_ROOT", "/Users/adrianwitas/go/src/github.vianttech.com/viant/mdp")
	repoRoot := envOr("DATLY_MDP_REPO", filepath.Join(mdpRoot, "repo", "dev"))
	routesRoot := filepath.Join(repoRoot, "Datly", "routes", "mdp")
	dqlRoot := filepath.Join(mdpRoot, "dql")
	if _, err := os.Stat(routesRoot); err != nil {
		t.Fatalf("routes root missing: %v", err)
	}
	if _, err := os.Stat(dqlRoot); err != nil {
		t.Fatalf("dql root missing: %v", err)
	}

	connectors := splitNonEmpty(os.Getenv("DATLY_MDP_CONNECTORS"))
	if len(connectors) == 0 {
		connectors = resolveConnectors([]string{
			"ci_ads|mysql|root:dev@tcp(127.0.0.1:3307)/ci_ads?parseTime=true&charset=utf8mb4&collation=utf8mb4_bin",
			"ci_ads_rw|mysql|root:dev@tcp(127.0.0.1:3307)/ci_ads?parseTime=true&charset=utf8mb4&collation=utf8mb4_bin",
			"bq_mdp|mysql|root:dev@tcp(127.0.0.1:3307)/ci_ads?parseTime=true&charset=utf8mb4&collation=utf8mb4_bin",
			"bq_automation|mysql|root:dev@tcp(127.0.0.1:3307)/ci_ads?parseTime=true&charset=utf8mb4&collation=utf8mb4_bin",
		})
	}

	type issue struct {
		route string
		msg   string
	}
	var issues []issue
	scanner :=
dqlscan.New() + _ = filepath.WalkDir(routesRoot, func(path string, d os.DirEntry, walkErr error) error { + if walkErr != nil || d.IsDir() { + return walkErr + } + if filepath.Ext(path) != ".yaml" { + return nil + } + base := filepath.Base(path) + if base == "producer.yaml" || strings.HasPrefix(path, filepath.Join(routesRoot, ".meta")) || strings.Contains(path, string(filepath.Separator)+".meta"+string(filepath.Separator)) { + return nil + } + rel, err := filepath.Rel(routesRoot, path) + if err != nil { + issues = append(issues, issue{route: path, msg: "failed to compute relative route path: " + err.Error()}) + return nil + } + ruleDir := filepath.Dir(rel) + ruleName := strings.TrimSuffix(filepath.Base(path), ".yaml") + dqlFile := filepath.Join(dqlRoot, ruleDir, ruleName+".dql") + if _, err = os.Stat(dqlFile); err != nil { + dqlFile = filepath.Join(dqlRoot, ruleDir, ruleName+".sql") + } + if _, err = os.Stat(dqlFile); err != nil { + t.Logf("skip %s: missing dql/sql counterpart", path) + return nil + } + modulePrefix := filepath.ToSlash(filepath.Join("mdp", ruleDir)) + scanned, err := scanner.Scan(context.Background(), &dqlscan.Request{ + DQLURL: dqlFile, + Repository: repoRoot, + ModulePrefix: modulePrefix, + APIPrefix: "/v1/api", + Connectors: connectors, + }) + if err != nil { + if strings.Contains(err.Error(), "failed to parse import statement") { + t.Logf("skip %s: %v", path, err) + return nil + } + issues = append(issues, issue{route: path, msg: "scan failed: " + err.Error()}) + return nil + } + fromDQL, err := dqlplan.BuildFromIR(scanned.IR) + if err != nil { + issues = append(issues, issue{route: path, msg: "build from dql ir failed: " + err.Error()}) + return nil + } + yamlBytes, err := os.ReadFile(path) + if err != nil { + issues = append(issues, issue{route: path, msg: "read route yaml failed: " + err.Error()}) + return nil + } + fromYAML, err := dqlplan.Build(yamlBytes) + if err != nil { + issues = append(issues, issue{route: path, msg: "build from route 
yaml failed: " + err.Error()}) + return nil + } + normalizeMDPCanonical(fromDQL.Canonical) + normalizeMDPCanonical(fromYAML.Canonical) + diff := Diff(fromDQL.Canonical, fromYAML.Canonical) + if len(diff) > 0 { + msg := "canonical diff issues: " + diff[0] + issues = append(issues, issue{route: path, msg: msg}) + } + return nil + }) + + if len(issues) == 0 { + return + } + limit := len(issues) + if limit > 40 { + limit = 40 + } + for i := 0; i < limit; i++ { + t.Logf("%s => %s", issues[i].route, issues[i].msg) + } + t.Fatalf("mdp parity issues: %d", len(issues)) +} + +func normalizeMDPCanonical(canonical map[string]any) { + routes, ok := canonical["Routes"].([]any) + if !ok { + return + } + for _, routeItem := range routes { + route, ok := routeItem.(map[string]any) + if !ok { + continue + } + input, ok := route["Input"].(map[string]any) + if !ok { + continue + } + delete(input, "Parameters") + } +} + +func envOr(key, fallback string) string { + if value := strings.TrimSpace(os.Getenv(key)); value != "" { + return value + } + return fallback +} + +func splitNonEmpty(csv string) []string { + var ret []string + for _, item := range strings.Split(csv, ",") { + item = strings.TrimSpace(item) + if item == "" { + continue + } + ret = append(ret, item) + } + return ret +} diff --git a/repository/shape/dql/parse/function.go b/repository/shape/dql/parse/function.go new file mode 100644 index 000000000..dda7bb244 --- /dev/null +++ b/repository/shape/dql/parse/function.go @@ -0,0 +1,70 @@ +package parse + +import ( + "fmt" + "strings" +) + +// FunctionHandler handles parsed DQL function call. +type FunctionHandler interface { + Name() string + Handle(call *FunctionCall, result *Result) error +} + +// FunctionHandlerFunc adapts function to handler. 
+type FunctionHandlerFunc struct { + FunctionName string + Fn func(call *FunctionCall, result *Result) error +} + +func (f FunctionHandlerFunc) Name() string { + return strings.ToLower(strings.TrimSpace(f.FunctionName)) +} + +func (f FunctionHandlerFunc) Handle(call *FunctionCall, result *Result) error { + if f.Fn == nil { + return nil + } + return f.Fn(call, result) +} + +// FunctionRegistry stores handlers by function name. +type FunctionRegistry struct { + items map[string]FunctionHandler +} + +// NewFunctionRegistry creates function registry. +func NewFunctionRegistry(handlers ...FunctionHandler) *FunctionRegistry { + ret := &FunctionRegistry{items: map[string]FunctionHandler{}} + for _, handler := range handlers { + ret.Register(handler) + } + return ret +} + +// Register registers function handler. +func (r *FunctionRegistry) Register(handler FunctionHandler) { + if r == nil || handler == nil { + return + } + name := strings.ToLower(strings.TrimSpace(handler.Name())) + if name == "" { + return + } + r.items[name] = handler +} + +func (r *FunctionRegistry) apply(call *FunctionCall, result *Result) error { + if r == nil || call == nil { + return nil + } + handler := r.items[strings.ToLower(call.Name)] + if handler == nil { + return nil + } + if err := handler.Handle(call, result); err != nil { + return fmt.Errorf("function %s failed: %w", call.Name, err) + } + call.Handled = true + return nil +} diff --git a/repository/shape/dql/parse/model.go b/repository/shape/dql/parse/model.go new file mode 100644 index 000000000..c1cf0a04f --- /dev/null +++ b/repository/shape/dql/parse/model.go @@ -0,0 +1,38 @@ +package parse + +import ( + "github.com/viant/datly/repository/shape/dql/decl" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/query" +) + +// Diagnostic describes parser issue with source position. 
type Diagnostic struct {
	Stage   string // pipeline stage that reported the issue
	Message string // human-readable description
	Offset  int    // byte offset into the DQL source
	Line    int    // 1-based line of the issue
	Column  int    // 1-based column of the issue
}

// FunctionCall captures declaration function invocation.
type FunctionCall struct {
	Name    string   // function name as written in the source
	Args    []string // raw argument texts
	Raw     string   // full invocation text
	Offset  int      // byte offset of the call in the source
	Line    int      // 1-based line of the call
	Column  int      // 1-based column of the call
	Handled bool     // set once a registered handler has processed the call
}

// Result is parser output.
type Result struct {
	Query        *query.Select       // parsed read query, when present
	Columns      sqlparser.Columns   // projected columns
	Declarations []*decl.Declaration // parsed declarations
	TypeContext  *typectx.Context    // package/import context
	Functions    []*FunctionCall     // declaration function invocations
	Diagnostics  []*Diagnostic       // accumulated parser issues
}
diff --git a/repository/shape/dql/parse/options.go b/repository/shape/dql/parse/options.go
new file mode 100644
index 000000000..640a0e6eb
--- /dev/null
+++ b/repository/shape/dql/parse/options.go
package parse

type (
	// UnknownNonReadMode controls how a non-SELECT, non-exec statement is reported.
	UnknownNonReadMode string

	// Options carries parser configuration.
	Options struct {
		UnknownNonReadMode UnknownNonReadMode
	}

	// Option mutates parser Options.
	Option func(*Options)
)

const (
	UnknownNonReadModeWarn  UnknownNonReadMode = "warn"
	UnknownNonReadModeError UnknownNonReadMode = "error"
)

// WithUnknownNonReadMode sets the reporting mode; a nil Options receiver is ignored.
func WithUnknownNonReadMode(mode UnknownNonReadMode) Option {
	return func(o *Options) {
		if o == nil {
			return
		}
		o.UnknownNonReadMode = mode
	}
}

// defaultOptions returns the parser defaults (warn on unknown non-read input).
func defaultOptions() Options {
	return Options{
		UnknownNonReadMode: UnknownNonReadModeWarn,
	}
}

// normalizeUnknownNonReadMode coerces any unrecognized mode back to warn.
func normalizeUnknownNonReadMode(mode UnknownNonReadMode) UnknownNonReadMode {
	switch mode {
	case UnknownNonReadModeWarn, UnknownNonReadModeError:
		return mode
	default:
		return UnknownNonReadModeWarn
	}
}
diff --git a/repository/shape/dql/parse/parser.go b/repository/shape/dql/parse/parser.go
new file mode 100644
index 000000000..cb8d81974
--- /dev/null
+++ b/repository/shape/dql/parse/parser.go
package parse

import (
	"errors"
	"strings"

	dqldiag "github.com/viant/datly/repository/shape/dql/diag"
	dqlpre "github.com/viant/datly/repository/shape/dql/preprocess"
	"github.com/viant/datly/repository/shape/dql/shape"
	dqlstmt "github.com/viant/datly/repository/shape/dql/statement"
	"github.com/viant/parsly"
	"github.com/viant/sqlparser"
	"github.com/viant/sqlparser/query"
)

// Parser parses DQL source into a shape Document.
type Parser struct {
	options Options
}

// New creates a DQL parser.
func New(opts ...Option) *Parser {
	options := defaultOptions()
	for _, opt := range opts {
		if opt != nil {
			opt(&options)
		}
	}
	// Coerce any invalid mode supplied via options back to a supported one.
	options.UnknownNonReadMode = normalizeUnknownNonReadMode(options.UnknownNonReadMode)
	return &Parser{options: options}
}

// Parse parses DQL and returns parsed document with diagnostics.
// Directives are extracted first; an error-severity directive diagnostic
// aborts parsing. DML-only input is accepted without producing a Query.
func (p *Parser) Parse(dql string) (*shape.Document, error) {
	doc := &shape.Document{Raw: dql}
	sql, ctx, directives, directiveDiagnostics := dqlpre.Extract(dql)
	doc.SQL = strings.TrimSpace(sql)
	doc.TypeContext = ctx
	doc.Directives = directives
	if len(directiveDiagnostics) > 0 {
		doc.Diagnostics = append(doc.Diagnostics, directiveDiagnostics...)
		for _, diagnostic := range directiveDiagnostics {
			if diagnostic != nil && diagnostic.Severity == shape.SeverityError {
				return doc, diagnostic
			}
		}
	}

	if doc.SQL == "" {
		d := &shape.Diagnostic{
			Code:     dqldiag.CodeParseEmpty,
			Severity: shape.SeverityError,
			Message:  "no SQL statement found",
			Hint:     "add SELECT/INSERT/UPDATE/DELETE statement after DQL directives",
			Span:     dqlpre.PointSpan(dql, 0),
		}
		doc.Diagnostics = append(doc.Diagnostics, d)
		return doc, d
	}

	statements := dqlstmt.New(sql)
	readStmt := firstReadStatement(statements)
	if readStmt == nil {
		if !hasExecStatement(statements) {
			// Neither a SELECT nor a recognized exec statement: severity is
			// configurable via UnknownNonReadMode.
			severity := shape.SeverityWarning
			if p.options.UnknownNonReadMode == UnknownNonReadModeError {
				severity = shape.SeverityError
			}
			doc.Diagnostics = append(doc.Diagnostics, &shape.Diagnostic{
				Code:     dqldiag.CodeParseUnknownNonRead,
				Severity: severity,
				Message:  "no readable SELECT statement detected",
				Hint:     "use SELECT for read parsing or compile as DML/handler template",
				Span:     dqlpre.PointSpan(dql, 0),
			})
			if severity == shape.SeverityError {
				return doc, doc.Diagnostics[len(doc.Diagnostics)-1]
			}
		}
		// DML-only statement sets are valid for parse contract.
		return doc, nil
	}
	// Parse only the first read statement, reporting positions relative to
	// the original source via the statement's start offset.
	querySQL := sql[readStmt.Start:readStmt.End]
	queryNode, diag, err := parseQueryWithDiagnosticAt(querySQL, dql, readStmt.Start)
	if diag != nil {
		doc.Diagnostics = append(doc.Diagnostics, diag)
	}
	if err != nil {
		return doc, diag
	}
	doc.Query = queryNode
	return doc, nil
}

// firstReadStatement returns the first statement classified as a read (SELECT),
// or nil when none exists.
func firstReadStatement(statements dqlstmt.Statements) *dqlstmt.Statement {
	for _, stmt := range statements {
		if stmt == nil {
			continue
		}
		if stmt.Kind == dqlstmt.KindRead {
			return stmt
		}
	}
	return nil
}

// hasExecStatement reports whether any statement is an exec (DML) statement.
func hasExecStatement(statements dqlstmt.Statements) bool {
	for _, stmt := range statements {
		if stmt != nil && stmt.IsExec {
			return true
		}
	}
	return false
}

// parseQueryWithDiagnosticAt parses sqlText as a SELECT; on failure it returns
// a syntax diagnostic whose span is positioned in the original source by
// adding baseOffset to the cursor position.
func parseQueryWithDiagnosticAt(sqlText, original string, baseOffset int) (*query.Select, *shape.Diagnostic, error) {
	cursor := parsly.NewCursor("", []byte(sqlText), 0)
	var diagnostic *shape.Diagnostic
	cursor.OnError = func(err error, cur *parsly.Cursor, _ interface{}) error {
		offset := 0
		if cur != nil {
			offset = cur.Pos
		}
		if offset < 0 {
			offset = 0
		}
		offset += baseOffset
		diagnostic = &shape.Diagnostic{
			Code:     dqldiag.CodeParseSyntax,
			Severity: shape.SeverityError,
			Message:  strings.TrimSpace(err.Error()),
			Hint:     "check SQL syntax near the reported location",
			Span:     dqlpre.PointSpan(original, offset),
		}
		return err
	}
	result := &query.Select{}
	err := sqlparser.Parse(cursor, result)
	if err != nil {
		// Fallback: the parser failed without invoking OnError, so point the
		// diagnostic at the statement start.
		if diagnostic == nil {
			diagnostic = &shape.Diagnostic{
				Code:     dqldiag.CodeParseSyntax,
				Severity: shape.SeverityError,
				Message:  strings.TrimSpace(err.Error()),
				Hint:     "check SQL syntax near the reported location",
				Span:     dqlpre.PointSpan(original, baseOffset),
			}
		}
		return nil, diagnostic, errors.New(diagnostic.Error())
	}
	return result, nil, nil
}
diff --git
a/repository/shape/dql/parse/parser_test.go b/repository/shape/dql/parse/parser_test.go new file mode 100644 index 000000000..6a3776dc2 --- /dev/null +++ b/repository/shape/dql/parse/parser_test.go @@ -0,0 +1,134 @@ +package parse + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func TestParser_Parse_TypeContext(t *testing.T) { + dql := ` +#package('mdp/performance') +#import('perf', 'github.com/acme/mdp/performance') +SELECT id FROM ORDERS t +` + parsed, err := New().Parse(dql) + require.NoError(t, err) + require.NotNil(t, parsed) + require.NotNil(t, parsed.TypeContext) + assert.Equal(t, "mdp/performance", parsed.TypeContext.DefaultPackage) + require.Len(t, parsed.TypeContext.Imports, 1) + assert.Equal(t, "perf", parsed.TypeContext.Imports[0].Alias) + assert.Equal(t, "github.com/acme/mdp/performance", parsed.TypeContext.Imports[0].Package) +} + +func TestParser_Parse_SpecialDirectives(t *testing.T) { + dql := ` +#settings($_ = $meta('docs/orders.md')) +#setting($_ = $connector('analytics')) +#settings($_ = $cache(true, '5m')) +#settings($_ = $mcp('orders.search', 'Search orders', 'docs/mcp/orders.md')) +SELECT id FROM ORDERS t +` + parsed, err := New().Parse(dql) + require.NoError(t, err) + require.NotNil(t, parsed) + require.NotNil(t, parsed.Directives) + assert.Equal(t, "docs/orders.md", parsed.Directives.Meta) + assert.Equal(t, "analytics", parsed.Directives.DefaultConnector) + require.NotNil(t, parsed.Directives.Cache) + assert.True(t, parsed.Directives.Cache.Enabled) + assert.Equal(t, "5m", parsed.Directives.Cache.TTL) + require.NotNil(t, parsed.Directives.MCP) + assert.Equal(t, "orders.search", parsed.Directives.MCP.Name) + assert.Equal(t, "Search orders", parsed.Directives.MCP.Description) + assert.Equal(t, "docs/mcp/orders.md", parsed.Directives.MCP.DescriptionPath) +} + 
+func TestParser_Parse_SyntaxErrorPosition(t *testing.T) { + dql := "SELECT id FROM ORDERS WHERE (" + parsed, err := New().Parse(dql) + require.Error(t, err) + require.NotNil(t, parsed) + require.NotEmpty(t, parsed.Diagnostics) + diag := parsed.Diagnostics[0] + assert.Equal(t, dqldiag.CodeParseSyntax, diag.Code) + assert.Equal(t, 1, diag.Span.Start.Line) + assert.Equal(t, 29, diag.Span.Start.Char) +} + +func TestParser_Parse_OnlyDirectives(t *testing.T) { + dql := "#package('x')\n#import('a','b')" + parsed, err := New().Parse(dql) + require.Error(t, err) + require.NotNil(t, parsed) + require.NotEmpty(t, parsed.Diagnostics) + assert.Equal(t, dqldiag.CodeParseEmpty, parsed.Diagnostics[0].Code) + assert.Equal(t, 1, parsed.Diagnostics[0].Span.Start.Line) + assert.Equal(t, 1, parsed.Diagnostics[0].Span.Start.Char) +} + +func TestParser_Parse_InvalidDirective_HasLineAndChar(t *testing.T) { + dql := "SELECT id FROM ORDERS t\n#import('alias')\nSELECT id FROM ORDERS t" + parsed, err := New().Parse(dql) + require.Error(t, err) + require.NotNil(t, parsed) + require.NotEmpty(t, parsed.Diagnostics) + diag := parsed.Diagnostics[0] + assert.Equal(t, dqldiag.CodeDirImport, diag.Code) + assert.Equal(t, 2, diag.Span.Start.Line) + assert.Equal(t, 1, diag.Span.Start.Char) +} + +func TestParser_Parse_DMLOnly_NoError(t *testing.T) { + dql := "INSERT INTO ORDERS(id) VALUES (1)" + parsed, err := New().Parse(dql) + require.NoError(t, err) + require.NotNil(t, parsed) + assert.Nil(t, parsed.Query) + assert.Empty(t, parsed.Diagnostics) +} + +func TestParser_Parse_Mixed_ReadAndExec_ParsesRead(t *testing.T) { + dql := "INSERT INTO ORDERS(id) VALUES (1)\nSELECT id FROM ORDERS t" + parsed, err := New().Parse(dql) + require.NoError(t, err) + require.NotNil(t, parsed) + require.NotNil(t, parsed.Query) + assert.Equal(t, "t", parsed.Query.From.Alias) +} + +func TestParser_Parse_UnknownNonRead_Warns(t *testing.T) { + dql := "$Foo.Bar($x)" + parsed, err := New().Parse(dql) + require.NoError(t, err) + 
require.NotNil(t, parsed) + assert.Nil(t, parsed.Query) + require.NotEmpty(t, parsed.Diagnostics) + assert.Equal(t, dqldiag.CodeParseUnknownNonRead, parsed.Diagnostics[len(parsed.Diagnostics)-1].Code) + assert.Equal(t, dqlshape.SeverityWarning, parsed.Diagnostics[len(parsed.Diagnostics)-1].Severity) +} + +func TestParser_Parse_UnknownNonRead_ErrorsWhenConfigured(t *testing.T) { + dql := "$Foo.Bar($x)" + parsed, err := New(WithUnknownNonReadMode(UnknownNonReadModeError)).Parse(dql) + require.Error(t, err) + require.NotNil(t, parsed) + assert.Nil(t, parsed.Query) + require.NotEmpty(t, parsed.Diagnostics) + assert.Equal(t, dqldiag.CodeParseUnknownNonRead, parsed.Diagnostics[len(parsed.Diagnostics)-1].Code) + assert.Equal(t, dqlshape.SeverityError, parsed.Diagnostics[len(parsed.Diagnostics)-1].Severity) +} + +func TestParser_Parse_UnknownNonRead_InvalidModeDefaultsToWarn(t *testing.T) { + dql := "$Foo.Bar($x)" + parsed, err := New(WithUnknownNonReadMode(UnknownNonReadMode("invalid"))).Parse(dql) + require.NoError(t, err) + require.NotNil(t, parsed) + require.NotEmpty(t, parsed.Diagnostics) + assert.Equal(t, dqldiag.CodeParseUnknownNonRead, parsed.Diagnostics[len(parsed.Diagnostics)-1].Code) + assert.Equal(t, dqlshape.SeverityWarning, parsed.Diagnostics[len(parsed.Diagnostics)-1].Severity) +} diff --git a/repository/shape/dql/plan/planner.go b/repository/shape/dql/plan/planner.go new file mode 100644 index 000000000..9b81f6a24 --- /dev/null +++ b/repository/shape/dql/plan/planner.go @@ -0,0 +1,609 @@ +package plan + +import ( + "fmt" + "github.com/viant/datly/repository/shape/dql/ir" + "gopkg.in/yaml.v3" + "regexp" + "sort" + "strings" +) + +var ( + routeFields = []string{"Name", "URI", "Method", "Description", "MCPTool", "Service"} + routeInputFields = []string{"Type", "Parameters"} + routeOutputFields = []string{"Type", "Parameters", "Exclude", "CaseFormat", "Tag"} + parameterFields = []string{"Name", "Required", "Tag", "ErrorStatusCode", "Cacheable", "Scope", 
"Connector", "Value", "Limit"} + viewFields = []string{"Name", "Table", "Mode", "AllowNulls", "RelationalConcurrency"} + selectorFields = []string{"Constraints", "Limit", "Namespace"} + templateFields = []string{"SourceURL", "Source", "Summary"} +) + +var tagMatcher = regexp.MustCompile(`([A-Za-z0-9_]+):"([^"]*)"`) +var veltyPlaceholderBraced = regexp.MustCompile(`\$\{([A-Za-z_][A-Za-z0-9_]*)\}`) + +// Result is canonicalized route YAML representation. +type Result struct { + Canonical map[string]any +} + +// Build creates a canonical map from route YAML. +func Build(routeYAML []byte) (*Result, error) { + if len(routeYAML) == 0 { + return nil, fmt.Errorf("dql plan: empty YAML") + } + var root map[string]any + if err := yaml.Unmarshal(routeYAML, &root); err != nil { + return nil, err + } + canonical := projectCanonical(root) + return &Result{Canonical: canonical}, nil +} + +// BuildFromIR canonicalizes IR without requiring YAML rendering/parsing. +func BuildFromIR(doc *ir.Document) (*Result, error) { + if doc == nil || doc.Root == nil { + return nil, fmt.Errorf("dql plan: empty IR") + } + canonical := projectCanonical(doc.Root) + return &Result{Canonical: canonical}, nil +} + +func projectCanonical(root map[string]any) map[string]any { + out := map[string]any{} + rootRefs := collectRootViewRefs(root["Routes"]) + if routes, ok := root["Routes"]; ok { + if canonical := canonicalRoutes(routes); len(canonical) > 0 { + out["Routes"] = canonical + } + } + if resource := toFlatMap(root["Resource"]); resource != nil { + if views := canonicalViews(resource["Views"], rootRefs); len(views) > 0 { + out["Resource"] = map[string]any{"Views": views} + } + } + return out +} + +func canonicalRoutes(raw any) []any { + items := canonicalSlice(raw) + var routes []map[string]any + for _, item := range items { + if normalized := toFlatMap(item); normalized != nil { + routes = append(routes, canonicalRoute(normalized)) + } + } + sort.SliceStable(routes, func(i, j int) bool { + return 
mapStringCompare(routes[i], routes[j], "Name", "URI") + }) + result := make([]any, len(routes)) + for i, r := range routes { + result[i] = r + } + return result +} + +func canonicalRoute(src map[string]any) map[string]any { + out := map[string]any{} + copyFields(out, src, routeFields) + if view := canonicalRouteView(src["View"]); view != nil { + out["View"] = view + } + if input := canonicalRouteIO(src["Input"], routeInputFields, true); len(input) > 0 { + out["Input"] = input + } + if output := canonicalRouteIO(src["Output"], routeOutputFields, false); len(output) > 0 { + out["Output"] = output + } + if with := canonicalStringList(src["With"]); len(with) > 0 { + out["With"] = with + } + return out +} + +func canonicalRouteView(raw any) map[string]any { + if normalized := toFlatMap(raw); normalized != nil { + return filterMap(normalized, []string{"Ref"}) + } + return nil +} + +func canonicalRouteIO(raw any, allowed []string, includeTypeName bool) map[string]any { + if normalized := toFlatMap(raw); normalized != nil { + out := map[string]any{} + typeMap := toFlatMap(normalized["Type"]) + for _, key := range allowed { + val, ok := normalized[key] + if !ok { + if key != "Parameters" { + continue + } + } + switch key { + case "Type": + if canonical := canonicalTypeWithName(typeMap, includeTypeName); len(canonical) > 0 { + out["Type"] = canonical + } + case "Parameters": + parameterRaw := val + if typeMap != nil && typeMap["Parameters"] != nil { + parameterRaw = typeMap["Parameters"] + } + if canonical := canonicalParameters(parameterRaw); len(canonical) > 0 { + out["Parameters"] = canonical + } + default: + out[key] = normalizeValue(val) + } + } + return out + } + return nil +} + +func canonicalTypeWithName(raw any, includeName bool) map[string]any { + keys := []string{"Package"} + if includeName { + keys = []string{"Name", "Package"} + } + return filterMap(toFlatMap(raw), keys) +} + +func canonicalParameters(raw any) []any { + items := canonicalSlice(raw) + var params 
[]map[string]any + for _, item := range items { + if normalized := toFlatMap(item); normalized != nil { + if param := canonicalParameter(normalized); len(param) > 0 { + params = append(params, param) + } + } + } + sort.SliceStable(params, func(i, j int) bool { + return mapStringCompare(params[i], params[j], "Name") + }) + result := make([]any, len(params)) + for i, p := range params { + result[i] = p + } + return result +} + +func canonicalParameter(src map[string]any) map[string]any { + if in := canonicalIn(src["In"]); len(in) > 0 && fmt.Sprint(in["Kind"]) == "component" { + return nil + } else if isSyntheticSubstituteParameter(src, in) { + return nil + } + out := map[string]any{} + copyFields(out, src, parameterFields) + if tag := canonicalTag(src["Tag"]); len(tag) > 0 { + out["TagMeta"] = tag + } + if in := canonicalIn(src["In"]); len(in) > 0 { + out["In"] = in + } + if schema := canonicalSchema(src["Schema"]); len(schema) > 0 { + out["Schema"] = schema + } + if output := canonicalOutput(src["Output"]); len(output) > 0 { + out["Output"] = output + } + if preds := canonicalPredicates(src["Predicates"]); len(preds) > 0 { + out["Predicates"] = preds + } + if loc := canonicalLocationInput(src["LocationInput"]); len(loc) > 0 { + out["LocationInput"] = loc + } + return out +} + +func isSyntheticSubstituteParameter(src map[string]any, in map[string]any) bool { + if strings.ToLower(fmt.Sprint(in["Kind"])) != "form" { + return false + } + name := strings.ToLower(fmt.Sprint(src["Name"])) + return strings.HasSuffix(name, "_table_suffix") +} + +func canonicalIn(raw any) map[string]any { + return filterMap(toFlatMap(raw), []string{"Kind", "Name"}) +} + +func canonicalSchema(raw any) map[string]any { + out := filterMap(toFlatMap(raw), []string{"Name", "Package", "DataType", "Cardinality"}) + if pkg, ok := out["Package"].(string); ok { + out["Package"] = normalizeSchemaPackage(pkg) + } + return out +} + +func canonicalPredicates(raw any) []any { + items := canonicalSlice(raw) 
+ var preds []map[string]any + for _, item := range items { + if normalized := toFlatMap(item); normalized != nil { + entry := filterMap(normalized, []string{"Name", "Ensure", "Group"}) + if args, ok := normalized["Args"]; ok { + entry["Args"] = normalizeValue(args) + } + if len(entry) > 0 { + preds = append(preds, entry) + } + } + } + sort.SliceStable(preds, func(i, j int) bool { + return mapStringCompare(preds[i], preds[j], "Name") + }) + result := make([]any, len(preds)) + for i, p := range preds { + result[i] = p + } + return result +} + +func canonicalLocationInput(raw any) map[string]any { + if normalized := toFlatMap(raw); normalized != nil { + out := map[string]any{} + copyFields(out, normalized, []string{"Name", "Package"}) + if params := canonicalParameters(normalized["Parameters"]); len(params) > 0 { + out["Parameters"] = params + } + return out + } + return nil +} + +func canonicalOutput(raw any) map[string]any { + if normalized := toFlatMap(raw); normalized != nil { + out := filterMap(normalized, []string{"Name", "Args"}) + if schema := canonicalSchema(normalized["Schema"]); len(schema) > 0 { + out["Schema"] = schema + } + return out + } + return nil +} + +func canonicalTag(raw any) map[string]any { + text := fmt.Sprint(raw) + if text == "" || text == "" { + return nil + } + parsed := map[string]string{} + for _, group := range tagMatcher.FindAllStringSubmatch(text, -1) { + if len(group) < 3 { + continue + } + parsed[group[1]] = group[2] + } + if len(parsed) == 0 { + return map[string]any{"Raw": text} + } + return map[string]any{ + "Raw": text, + "Pairs": parsed, + } +} + +func canonicalViews(raw any, roots []string) []any { + items := canonicalSlice(raw) + allowed := collectReachableViews(items, roots) + var views []map[string]any + for _, item := range items { + if normalized := toFlatMap(item); normalized != nil { + name := fmt.Sprint(normalized["Name"]) + if len(allowed) > 0 && !allowed[name] { + continue + } + view := canonicalView(normalized) + 
if len(view) > 0 { + views = append(views, view) + } + } + } + sort.SliceStable(views, func(i, j int) bool { + return mapStringCompare(views[i], views[j], "Name") + }) + result := make([]any, len(views)) + for i, v := range views { + result[i] = v + } + return result +} + +func canonicalView(src map[string]any) map[string]any { + out := map[string]any{} + copyFields(out, src, viewFields) + if partitioned := canonicalPartitioned(src["Partitioned"]); len(partitioned) > 0 { + out["Partitioned"] = partitioned + } + if connector := canonicalConnector(src["Connector"]); len(connector) > 0 { + out["Connector"] = connector + } + if selector := canonicalSelector(src["Selector"]); len(selector) > 0 { + out["Selector"] = selector + } + if strings.ToLower(fmt.Sprint(src["Mode"])) != "sqlexec" { + if template := canonicalTemplate(src["Template"]); len(template) > 0 { + out["Template"] = template + } + } + return out +} + +func canonicalPartitioned(raw any) map[string]any { + return filterMap(toFlatMap(raw), []string{"DataType", "Concurrency"}) +} + +func canonicalConnector(raw any) map[string]any { + return filterMap(toFlatMap(raw), []string{"Ref"}) +} + +func canonicalSelector(raw any) map[string]any { + if normalized := toFlatMap(raw); normalized != nil { + out := map[string]any{} + copyFields(out, normalized, selectorFields) + if constraints := canonicalSelectorConstraints(normalized["Constraints"]); len(constraints) > 0 { + out["Constraints"] = constraints + } + return out + } + return nil +} + +func canonicalSelectorConstraints(raw any) map[string]any { + return filterMap(toFlatMap(raw), []string{"Criteria", "Filterable", "Limit", "Offset", "OrderBy", "Projection"}) +} + +func canonicalTemplate(raw any) map[string]any { + if normalized := toFlatMap(raw); normalized != nil { + out := map[string]any{} + copyFields(out, normalized, templateFields) + if summary := canonicalSummary(normalized["Summary"]); len(summary) > 0 { + out["Summary"] = summary + } + if with := 
canonicalTemplateWith(normalized["With"]); len(with) > 0 {
out["With"] = with
}
return out
}
return nil
}

// canonicalSummary reduces a template Summary node to its comparable fields
// (Kind, Name, Source), attaching a canonicalized Schema sub-map when present.
// Returns nil when raw is not map-shaped.
func canonicalSummary(raw any) map[string]any {
	if normalized := toFlatMap(raw); normalized != nil {
		out := copyMap(filterMap(normalized, []string{"Kind", "Name", "Source"}))
		if schema := canonicalSummarySchema(normalized["Schema"]); len(schema) > 0 {
			out["Schema"] = schema
		}
		return out
	}
	return nil
}

// canonicalSummarySchema keeps only the schema identity fields of a summary.
func canonicalSummarySchema(raw any) map[string]any {
	return filterMap(toFlatMap(raw), []string{"Name", "Package", "DataType"})
}

// canonicalTemplateWith canonicalizes a template-level With relation list.
func canonicalTemplateWith(raw any) []any {
	return canonicalWithList(raw)
}

// canonicalViewWith canonicalizes a view-level With relation list.
func canonicalViewWith(raw any) []any {
	return canonicalWithList(raw)
}

// canonicalWithList normalizes each relation entry, drops empty ones, and
// sorts the result by Name then Holder so output order is deterministic.
func canonicalWithList(raw any) []any {
	items := canonicalSlice(raw)
	var nodes []map[string]any
	for _, item := range items {
		if normalized := toFlatMap(item); normalized != nil {
			if node := canonicalWithNode(normalized); len(node) > 0 {
				nodes = append(nodes, node)
			}
		}
	}
	sort.SliceStable(nodes, func(i, j int) bool {
		return mapStringCompare(nodes[i], nodes[j], "Name", "Holder")
	})
	// Re-box as []any so callers get a uniform canonical slice type.
	result := make([]any, len(nodes))
	for i, n := range nodes {
		result[i] = n
	}
	return result
}

// canonicalWithNode projects one relation node to its comparable fields plus
// canonicalized Of/On sub-structures (omitted when empty).
func canonicalWithNode(src map[string]any) map[string]any {
	out := map[string]any{}
	copyFields(out, src, []string{"Name", "Holder", "Cardinality", "IncludeColumn"})
	if of := canonicalOf(src["Of"]); len(of) > 0 {
		out["Of"] = of
	}
	if on := canonicalOn(src["On"]); len(on) > 0 {
		out["On"] = on
	}
	return out
}

// canonicalOf keeps only the referenced-view identity of a relation Of node.
func canonicalOf(raw any) map[string]any {
	return filterMap(toFlatMap(raw), []string{"Name", "Ref"})
}

// canonicalOn canonicalizes relation link entries (Column/Field pairs),
// dropping empty entries and sorting by Column then Field.
func canonicalOn(raw any) []any {
	items := canonicalSlice(raw)
	var list []map[string]any
	for _, item := range items {
		if normalized := toFlatMap(item); normalized != nil {
			entry := filterMap(normalized, []string{"Column", "Field"})
			if len(entry) > 0 {
				list = append(list, entry)
			}
		}
	}
	sort.SliceStable(list, func(i, j int) bool {
		return mapStringCompare(list[i], list[j], "Column", "Field")
	})
	result := make([]any, len(list))
	for i, n := range list {
		result[i] = n
	}
	return result
}

// canonicalSlice returns raw as a normalized []any, or nil when it is not a slice.
func canonicalSlice(raw any) []any {
	if normalized, ok := normalizeValue(raw).([]any); ok {
		return normalized
	}
	return nil
}

// toFlatMap returns raw as a normalized map[string]any, or nil otherwise.
func toFlatMap(raw any) map[string]any {
	if normalized, ok := normalizeValue(raw).(map[string]any); ok {
		return normalized
	}
	return nil
}

// normalizeValue recursively converts YAML-decoded values into a uniform
// shape: map[any]any keys are stringified, slices/maps are normalized
// element-wise, and strings pass through normalizeTextValue.
func normalizeValue(v any) any {
	switch actual := v.(type) {
	case map[string]any:
		ret := map[string]any{}
		for k, val := range actual {
			ret[k] = normalizeValue(val)
		}
		return ret
	case map[any]any:
		ret := map[string]any{}
		for k, val := range actual {
			ret[fmt.Sprint(k)] = normalizeValue(val)
		}
		return ret
	case []any:
		ret := make([]any, len(actual))
		for i, item := range actual {
			ret[i] = normalizeValue(item)
		}
		return ret
	default:
		if text, ok := actual.(string); ok {
			return normalizeTextValue(text)
		}
		return actual
	}
}

// normalizeTextValue rewrites text via the veltyPlaceholderBraced regexp
// (declared elsewhere in the package), substituting a literal "$" plus the
// first capture group — presumably collapsing braced velty placeholders to
// their bare form; TODO confirm against the regexp definition.
func normalizeTextValue(text string) string {
	if text == "" {
		return text
	}
	return veltyPlaceholderBraced.ReplaceAllString(text, `$$$1`)
}

// normalizeSchemaPackage maps legacy package aliases to their canonical names;
// any other package name is returned unchanged.
func normalizeSchemaPackage(pkg string) string {
	if pkg == "auto" {
		return "automation"
	}
	if pkg == "allocator" {
		return "bidalloc"
	}
	return pkg
}

// copyFields copies the listed keys from src into dst (normalized); keys
// absent in src are skipped.
func copyFields(dst, src map[string]any, keys []string) {
	for _, key := range keys {
		if val, ok := src[key]; ok {
			dst[key] = normalizeValue(val)
		}
	}
}

// filterMap returns a new map holding only the listed keys of src, with
// values normalized; nil-valued entries are dropped. Returns nil for nil src.
func filterMap(src map[string]any, keys []string) map[string]any {
	if src == nil {
		return nil
	}
	out := map[string]any{}
	for _, key := range keys {
		if val, ok := src[key]; ok {
			dst := normalizeValue(val)
			if dst != nil {
				out[key] = dst
			}
		}
	}
	return out
}

// canonicalStringList renders each slice element as a string (via fmt.Sprint
// for non-strings) and sorts the result.
func canonicalStringList(raw any) []string {
	items := canonicalSlice(raw)
	var list []string
	for _, item := range items {
		switch val := item.(type) {
case string: + list = append(list, val) + default: + list = append(list, fmt.Sprint(val)) + } + } + sort.Strings(list) + return list +} + +func mapStringCompare(a, b map[string]any, keys ...string) bool { + for _, key := range keys { + ai := fmt.Sprint(a[key]) + bi := fmt.Sprint(b[key]) + if ai != bi { + return ai < bi + } + } + return fmt.Sprint(a) < fmt.Sprint(b) +} + +func copyMap(src map[string]any) map[string]any { + if src == nil { + return nil + } + out := make(map[string]any, len(src)) + for k, v := range src { + out[k] = v + } + return out +} + +func collectRootViewRefs(raw any) []string { + items := canonicalSlice(raw) + unique := map[string]bool{} + var result []string + for _, item := range items { + route := toFlatMap(item) + if route == nil { + continue + } + view := toFlatMap(route["View"]) + if view == nil { + continue + } + ref := strings.TrimSpace(fmt.Sprint(view["Ref"])) + if ref == "" || unique[ref] { + continue + } + unique[ref] = true + result = append(result, ref) + } + sort.Strings(result) + return result +} + +func collectReachableViews(rawViews []any, roots []string) map[string]bool { + if len(roots) == 0 { + return nil + } + seen := map[string]bool{} + for _, root := range roots { + if root != "" { + seen[root] = true + } + } + return seen +} diff --git a/repository/shape/dql/plan/planner_test.go b/repository/shape/dql/plan/planner_test.go new file mode 100644 index 000000000..2a0046f0b --- /dev/null +++ b/repository/shape/dql/plan/planner_test.go @@ -0,0 +1,164 @@ +package plan + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuild_ProjectRouteTypeParametersWithTags(t *testing.T) { + yaml := ` +Routes: + - Name: Example + URI: /v1/api/example + Method: GET + Input: + Type: + Name: ExampleInput + Package: example + Parameters: + - Name: Auth + In: + Kind: component + Name: acl/auth + - Name: Id + Required: true + In: + Kind: query + Name: id + Tag: 'json:",omitempty" anonymous:"true"' + ErrorStatusCode: 401 
+ Cacheable: true + Scope: req + Connector: ci_ads + Limit: 25 + Schema: + DataType: int + Cardinality: One + Output: + Type: + Name: ExampleOutput + Package: example + Parameters: + - Name: Data + In: + Kind: output + Name: view + Output: + Name: Json + Args: ["a", "b"] + Schema: + DataType: string + Cardinality: One +` + result, err := Build([]byte(yaml)) + require.NoError(t, err) + require.NotNil(t, result) + + routes, ok := result.Canonical["Routes"].([]any) + require.True(t, ok) + require.Len(t, routes, 1) + + route, ok := routes[0].(map[string]any) + require.True(t, ok) + + input, ok := route["Input"].(map[string]any) + require.True(t, ok) + params, ok := input["Parameters"].([]any) + require.True(t, ok) + require.Len(t, params, 1, "component-kind parameter should be excluded from canonical input shape") + + param, ok := params[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "Id", param["Name"]) + require.Equal(t, "json:\",omitempty\" anonymous:\"true\"", param["Tag"]) + require.EqualValues(t, 401, param["ErrorStatusCode"]) + require.Equal(t, true, param["Cacheable"]) + require.Equal(t, "req", param["Scope"]) + require.Equal(t, "ci_ads", param["Connector"]) + require.EqualValues(t, 25, param["Limit"]) + + tagMeta, ok := param["TagMeta"].(map[string]any) + require.True(t, ok) + require.Equal(t, "json:\",omitempty\" anonymous:\"true\"", tagMeta["Raw"]) + pairs, ok := tagMeta["Pairs"].(map[string]string) + require.True(t, ok) + require.Equal(t, ",omitempty", pairs["json"]) + require.Equal(t, "true", pairs["anonymous"]) + + output, ok := route["Output"].(map[string]any) + require.True(t, ok) + outParams, ok := output["Parameters"].([]any) + require.True(t, ok) + require.Len(t, outParams, 1) + outParam, ok := outParams[0].(map[string]any) + require.True(t, ok) + outMeta, ok := outParam["Output"].(map[string]any) + require.True(t, ok) + require.Equal(t, "Json", outMeta["Name"]) +} + +func TestValidateRelations_AliasAndColumnsWithLineDetails(t 
*testing.T) { + routeYAML := ` +Resource: + Views: + - Name: Parent + Template: + Source: |- + SELECT p.ID, p.CAMPAIGN_ID FROM CI_PARENT p + With: + - Name: campaign + Holder: Campaign + Cardinality: One + On: + - Column: MISSING_PARENT + Namespace: p + Of: + Ref: Child + On: + - Column: MISSING_CHILD + Namespace: missing_alias + - Name: Child + Template: + Source: |- + SELECT c.ID FROM CI_CHILD c +` + err := ValidateRelations([]byte(routeYAML)) + require.Error(t, err) + require.Contains(t, err.Error(), "dql plan relation validation failed") + require.Contains(t, err.Error(), "line") + require.Contains(t, err.Error(), "alias=\"missing_alias\"") + require.Contains(t, err.Error(), "column=\"MISSING_PARENT\"") + require.Contains(t, err.Error(), "column=\"MISSING_CHILD\"") + require.Contains(t, err.Error(), "column not projected") + require.Contains(t, err.Error(), "alias not present in SQL/selector namespace") +} + +func TestValidateRelations_AllowsValidRelationAliasAndColumns(t *testing.T) { + routeYAML := ` +Resource: + Views: + - Name: Parent + Template: + Source: |- + SELECT p.ID, p.CAMPAIGN_ID FROM CI_PARENT p + With: + - Name: campaign + Holder: Campaign + Cardinality: One + On: + - Column: CAMPAIGN_ID + Namespace: p + Of: + Ref: Child + On: + - Column: ID + Namespace: c + - Name: Child + Template: + Source: |- + SELECT c.ID FROM CI_CHILD c +` + err := ValidateRelations([]byte(routeYAML)) + require.NoError(t, err) +} diff --git a/repository/shape/dql/plan/relation_sql.go b/repository/shape/dql/plan/relation_sql.go new file mode 100644 index 000000000..5530b68b4 --- /dev/null +++ b/repository/shape/dql/plan/relation_sql.go @@ -0,0 +1,82 @@ +package plan + +import ( + "strings" + + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/node" + "github.com/viant/sqlparser/query" +) + +func analyzeSQL(source string) (map[string]bool, projectionMeta, bool) { + aliases := map[string]bool{} + proj := projectionMeta{Columns: 
map[string]bool{}} + source = strings.TrimSpace(source) + if source == "" { + return aliases, proj, false + } + query, err := sqlparser.ParseQuery(source) + if err != nil || query == nil { + return aliases, proj, false + } + collectSQLAliases(query, aliases) + collectSQLProjection(query, &proj) + return aliases, proj, true +} + +func collectSQLAliases(query *query.Select, aliases map[string]bool) { + registerAlias(aliases, query.From.Alias) + registerFromNodeAlias(aliases, query.From.X) + for _, join := range query.Joins { + if join == nil { + continue + } + registerAlias(aliases, join.Alias) + registerFromNodeAlias(aliases, join.With) + } +} + +func collectSQLProjection(query *query.Select, projection *projectionMeta) { + columns := sqlparser.NewColumns(query.List) + projection.HasStar = columns.IsStarExpr() + for _, col := range columns { + if col == nil { + continue + } + registerProjection(projection.Columns, col.Name) + registerProjection(projection.Columns, col.Alias) + registerProjection(projection.Columns, col.Expression) + } +} + +func registerProjection(index map[string]bool, value string) { + value = strings.TrimSpace(value) + if value == "" || strings.Contains(value, "*") { + return + } + index[normalizedProjectionKey(value)] = true + if i := strings.LastIndex(value, "."); i != -1 && i+1 < len(value) { + suffix := strings.TrimSpace(value[i+1:]) + if suffix != "" { + index[normalizedProjectionKey(suffix)] = true + } + } +} + +func registerAlias(index map[string]bool, alias string) { + alias = strings.TrimSpace(alias) + if alias == "" { + return + } + index[strings.ToLower(alias)] = true +} + +func registerFromNodeAlias(index map[string]bool, n node.Node) { + switch actual := n.(type) { + case *expr.Ident: + registerAlias(index, actual.Name) + case *expr.Selector: + registerAlias(index, actual.Name) + } +} diff --git a/repository/shape/dql/plan/relation_types.go b/repository/shape/dql/plan/relation_types.go new file mode 100644 index 000000000..f5daa1879 
--- /dev/null +++ b/repository/shape/dql/plan/relation_types.go @@ -0,0 +1,32 @@ +package plan + +type relationLink struct { + Line int + Column string + Namespace string +} + +type relationMeta struct { + Line int + Name string + Holder string + Ref string + On []relationLink + OfOn []relationLink + PairCount int +} + +type projectionMeta struct { + Columns map[string]bool + HasStar bool +} + +type viewMeta struct { + Name string + Line int + HasSQL bool + Aliases map[string]bool + Namespaces map[string]bool + Projection projectionMeta + Relations []relationMeta +} diff --git a/repository/shape/dql/plan/relation_validate.go b/repository/shape/dql/plan/relation_validate.go new file mode 100644 index 000000000..c0038c293 --- /dev/null +++ b/repository/shape/dql/plan/relation_validate.go @@ -0,0 +1,204 @@ +package plan + +import ( + "fmt" + "sort" + "strings" + + "gopkg.in/yaml.v3" +) + +// ValidateRelations validates relation links in generated route YAML. +func ValidateRelations(routeYAML []byte) error { + var root map[string]any + if err := yaml.Unmarshal(routeYAML, &root); err != nil { + return err + } + views := extractViews(root) + if len(views) == 0 { + return nil + } + lineIndex, err := collectViewMeta(routeYAML) + if err != nil { + return err + } + viewIndex := buildViewIndex(views, lineIndex) + issues := collectRelationIssues(viewIndex) + if len(issues) == 0 { + return nil + } + return fmt.Errorf("dql plan relation validation failed:\n- %s", strings.Join(issues, "\n- ")) +} + +func buildViewIndex(views []any, lineIndex map[string]*viewMeta) map[string]*viewMeta { + viewIndex := map[string]*viewMeta{} + for _, item := range views { + viewMap := toFlatMap(item) + if viewMap == nil { + continue + } + name := strings.TrimSpace(fmt.Sprint(viewMap["Name"])) + if name == "" { + continue + } + meta := lineIndex[name] + if meta == nil { + meta = &viewMeta{Name: name, Namespaces: map[string]bool{}} + lineIndex[name] = meta + } + applyViewRuntimeSQLMeta(viewMap, meta) 
+ viewIndex[name] = meta + } + return viewIndex +} + +func applyViewRuntimeSQLMeta(viewMap map[string]any, meta *viewMeta) { + template := toFlatMap(viewMap["Template"]) + if template != nil { + source := strings.TrimSpace(fmt.Sprint(template["Source"])) + aliases, projection, hasSQL := analyzeSQL(source) + if hasSQL { + meta.HasSQL = true + } + if len(aliases) > 0 { + meta.Aliases = aliases + } + meta.Projection = projection + } + selector := toFlatMap(viewMap["Selector"]) + if selector != nil { + registerAlias(meta.Namespaces, fmt.Sprint(selector["Namespace"])) + } + if meta.Aliases == nil { + meta.Aliases = map[string]bool{} + } + if meta.Namespaces == nil { + meta.Namespaces = map[string]bool{} + } + for alias := range meta.Aliases { + meta.Namespaces[alias] = true + } + if meta.Projection.Columns == nil { + meta.Projection.Columns = map[string]bool{} + } +} + +func collectRelationIssues(viewIndex map[string]*viewMeta) []string { + var issues []string + for _, parent := range viewIndex { + for _, rel := range parent.Relations { + ref := viewIndex[strings.TrimSpace(rel.Ref)] + for i := 0; i < rel.PairCount; i++ { + left, right := linkAt(rel.On, i), linkAt(rel.OfOn, i) + issues = append(issues, validateParentLink(parent, rel, i, left)...) + issues = append(issues, validateRefLink(parent, ref, rel, i, right)...) 
+ } + } + } + return issues +} + +func validateParentLink(parent *viewMeta, rel relationMeta, i int, left *relationLink) []string { + if left == nil { + return []string{fmt.Sprintf("line %d view=%q relation=%q holder=%q link=%d side=parent: missing On link entry", rel.Line, parent.Name, rel.Name, rel.Holder, i)} + } + return validateLink(parent, rel, "parent", i, *left) +} + +func validateRefLink(parent, ref *viewMeta, rel relationMeta, i int, right *relationLink) []string { + if right == nil { + return []string{fmt.Sprintf("line %d view=%q relation=%q holder=%q link=%d side=ref: missing Of.On link entry", rel.Line, parent.Name, rel.Name, rel.Holder, i)} + } + if ref != nil { + return validateLink(ref, rel, "ref", i, *right) + } + if strings.TrimSpace(rel.Ref) == "" { + return nil + } + return []string{fmt.Sprintf("line %d view=%q relation=%q holder=%q link=%d side=ref: referenced view %q not found", right.Line, parent.Name, rel.Name, rel.Holder, i, rel.Ref)} +} + +func linkAt(links []relationLink, i int) *relationLink { + if i < 0 || i >= len(links) { + return nil + } + return &links[i] +} + +func validateLink(view *viewMeta, rel relationMeta, side string, index int, link relationLink) []string { + if view == nil { + return nil + } + line := link.Line + if line == 0 { + line = rel.Line + } + column := strings.TrimSpace(link.Column) + alias := strings.TrimSpace(link.Namespace) + if column == "" { + return []string{fmt.Sprintf("line %d view=%q relation=%q holder=%q link=%d side=%s: empty column", line, view.Name, rel.Name, rel.Holder, index, side)} + } + + var issues []string + columnProjected := true + if view.HasSQL && !view.Projection.HasStar && len(view.Projection.Columns) > 0 { + columnProjected = hasProjectionColumn(view.Projection.Columns, column) + if !columnProjected { + issues = append(issues, fmt.Sprintf("line %d view=%q relation=%q holder=%q link=%d side=%s alias=%q column=%q: column not projected (columns=%v)", line, view.Name, rel.Name, rel.Holder, 
index, side, alias, column, sortedKeys(view.Projection.Columns))) + } + } + if alias != "" && view.HasSQL && !columnProjected && !view.Namespaces[strings.ToLower(alias)] { + issues = append(issues, fmt.Sprintf("line %d view=%q relation=%q holder=%q link=%d side=%s alias=%q column=%q: alias not present in SQL/selector namespace (namespaces=%v)", line, view.Name, rel.Name, rel.Holder, index, side, alias, column, sortedKeys(view.Namespaces))) + } + return issues +} + +func hasProjectionColumn(columns map[string]bool, column string) bool { + for _, candidate := range projectionCandidates(column) { + if columns[candidate] { + return true + } + } + return false +} + +func projectionCandidates(column string) []string { + column = strings.TrimSpace(column) + if column == "" { + return nil + } + result := []string{normalizedProjectionKey(column)} + if i := strings.LastIndex(column, "."); i != -1 && i+1 < len(column) { + result = append(result, normalizedProjectionKey(column[i+1:])) + } + return result +} + +func normalizedProjectionKey(value string) string { + return strings.ToLower(strings.Trim(value, "`\"' ")) +} + +func extractViews(root map[string]any) []any { + resource := toFlatMap(root["Resource"]) + if resource == nil { + return nil + } + return canonicalSlice(resource["Views"]) +} + +func collectViewMeta(routeYAML []byte) (map[string]*viewMeta, error) { + var rootNode yaml.Node + if err := yaml.Unmarshal(routeYAML, &rootNode); err != nil { + return nil, err + } + return parseViewMetaNodes(&rootNode), nil +} + +func sortedKeys(index map[string]bool) []string { + ret := make([]string, 0, len(index)) + for key := range index { + ret = append(ret, key) + } + sort.Strings(ret) + return ret +} diff --git a/repository/shape/dql/plan/relation_yaml.go b/repository/shape/dql/plan/relation_yaml.go new file mode 100644 index 000000000..8a141a1af --- /dev/null +++ b/repository/shape/dql/plan/relation_yaml.go @@ -0,0 +1,164 @@ +package plan + +import ( + "strings" + + 
"gopkg.in/yaml.v3" +) + +func parseViewMetaNodes(rootNode *yaml.Node) map[string]*viewMeta { + result := map[string]*viewMeta{} + views := viewsNode(rootNode) + if views == nil || views.Kind != yaml.SequenceNode { + return result + } + for _, item := range views.Content { + meta := parseViewMeta(item) + if meta == nil || strings.TrimSpace(meta.Name) == "" { + continue + } + result[meta.Name] = meta + } + return result +} + +func parseViewMeta(item *yaml.Node) *viewMeta { + viewMap := nodeMapping(item) + if viewMap == nil { + return nil + } + name := strings.TrimSpace(nodeString(mappingValue(viewMap, "Name"))) + if name == "" { + return nil + } + meta := &viewMeta{ + Name: name, + Line: item.Line, + Aliases: map[string]bool{}, + Namespaces: map[string]bool{}, + Projection: projectionMeta{Columns: map[string]bool{}}, + } + parseViewTemplateMeta(viewMap, meta) + parseViewSelectorMeta(viewMap, meta) + parseViewRelationsMeta(viewMap, meta) + return meta +} + +func parseViewTemplateMeta(viewMap map[string]*yaml.Node, meta *viewMeta) { + template := nodeMapping(mappingValue(viewMap, "Template")) + sourceNode := mappingValue(template, "Source") + if sourceNode == nil { + return + } + aliases, projection, hasSQL := analyzeSQL(nodeString(sourceNode)) + meta.HasSQL = hasSQL + if len(aliases) > 0 { + meta.Aliases = aliases + } + if len(projection.Columns) > 0 || projection.HasStar { + meta.Projection = projection + } +} + +func parseViewSelectorMeta(viewMap map[string]*yaml.Node, meta *viewMeta) { + selector := nodeMapping(mappingValue(viewMap, "Selector")) + registerAlias(meta.Namespaces, nodeString(mappingValue(selector, "Namespace"))) + for alias := range meta.Aliases { + meta.Namespaces[alias] = true + } +} + +func parseViewRelationsMeta(viewMap map[string]*yaml.Node, meta *viewMeta) { + with := mappingValue(viewMap, "With") + if with == nil || with.Kind != yaml.SequenceNode { + return + } + for _, relItem := range with.Content { + rel := parseRelationMeta(relItem) + if 
rel != nil { + meta.Relations = append(meta.Relations, *rel) + } + } +} + +func parseRelationMeta(relItem *yaml.Node) *relationMeta { + relMap := nodeMapping(relItem) + if relMap == nil { + return nil + } + rel := &relationMeta{ + Line: relItem.Line, + Name: nodeString(mappingValue(relMap, "Name")), + Holder: nodeString(mappingValue(relMap, "Holder")), + On: parseLinkNodes(mappingValue(relMap, "On")), + } + ofMap := nodeMapping(mappingValue(relMap, "Of")) + rel.Ref = nodeString(mappingValue(ofMap, "Ref")) + if rel.Ref == "" { + rel.Ref = nodeString(mappingValue(ofMap, "Name")) + } + rel.OfOn = parseLinkNodes(mappingValue(ofMap, "On")) + rel.PairCount = len(rel.On) + if len(rel.OfOn) > rel.PairCount { + rel.PairCount = len(rel.OfOn) + } + return rel +} + +func parseLinkNodes(seq *yaml.Node) []relationLink { + if seq == nil || seq.Kind != yaml.SequenceNode { + return nil + } + ret := make([]relationLink, 0, len(seq.Content)) + for _, item := range seq.Content { + linkMap := nodeMapping(item) + if linkMap == nil { + ret = append(ret, relationLink{Line: item.Line}) + continue + } + ret = append(ret, relationLink{ + Line: item.Line, + Column: nodeString(mappingValue(linkMap, "Column")), + Namespace: nodeString(mappingValue(linkMap, "Namespace")), + }) + } + return ret +} + +func viewsNode(rootNode *yaml.Node) *yaml.Node { + rootMap := nodeMapping(rootNode) + resource := mappingValue(rootMap, "Resource") + resourceMap := nodeMapping(resource) + return mappingValue(resourceMap, "Views") +} + +func nodeMapping(n *yaml.Node) map[string]*yaml.Node { + if n == nil { + return nil + } + if n.Kind == yaml.DocumentNode && len(n.Content) > 0 { + n = n.Content[0] + } + if n.Kind != yaml.MappingNode { + return nil + } + ret := map[string]*yaml.Node{} + for i := 0; i+1 < len(n.Content); i += 2 { + ret[n.Content[i].Value] = n.Content[i+1] + } + return ret +} + +func mappingValue(m map[string]*yaml.Node, key string) *yaml.Node { + if m == nil { + return nil + } + return m[key] +} + 
+func nodeString(n *yaml.Node) string { + if n == nil { + return "" + } + return strings.TrimSpace(n.Value) +} diff --git a/repository/shape/dql/preprocess/diagnostics.go b/repository/shape/dql/preprocess/diagnostics.go new file mode 100644 index 000000000..8542dfbf8 --- /dev/null +++ b/repository/shape/dql/preprocess/diagnostics.go @@ -0,0 +1,13 @@ +package preprocess + +import dqlshape "github.com/viant/datly/repository/shape/dql/shape" + +func directiveDiagnostic(code, message, hint, text string, offset int) *dqlshape.Diagnostic { + return &dqlshape.Diagnostic{ + Code: code, + Severity: dqlshape.SeverityError, + Message: message, + Hint: hint, + Span: pointSpan(text, offset), + } +} diff --git a/repository/shape/dql/preprocess/directive_parser.go b/repository/shape/dql/preprocess/directive_parser.go new file mode 100644 index 000000000..9a13320cc --- /dev/null +++ b/repository/shape/dql/preprocess/directive_parser.go @@ -0,0 +1,186 @@ +package preprocess + +import "strings" + +type directiveCall struct { + name string + args []string + start int +} + +func scanDollarCalls(input string, names map[string]bool) []directiveCall { + result := make([]directiveCall, 0) + for i := 0; i < len(input); { + if input[i] != '$' || i+1 >= len(input) || !isIdentifierStart(input[i+1]) { + i++ + continue + } + start := i + 1 + i += 2 + for i < len(input) && isIdentifierPart(input[i]) { + i++ + } + name := strings.ToLower(input[start:i]) + if !names[name] { + continue + } + j := skipSpaces(input, i) + if j >= len(input) || input[j] != '(' { + continue + } + body, end, ok := readCallBody(input, j) + if !ok { + continue + } + result = append(result, directiveCall{ + name: name, + args: splitCallArgs(body), + start: start - 1, + }) + i = end + 1 + } + return result +} + +func readCallBody(input string, openParen int) (string, int, bool) { + depth := 0 + quote := byte(0) + for i := openParen; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) 
{
				// Skip the escaped character inside a quoted string.
				i++
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '\'' || ch == '"' {
			quote = ch
			continue
		}
		if ch == '(' {
			depth++
			continue
		}
		if ch == ')' {
			depth--
			if depth == 0 {
				// Matching close paren found: return the body between the parens.
				return input[openParen+1 : i], i, true
			}
		}
	}
	// Unbalanced parentheses: signal failure.
	return "", -1, false
}

// splitCallArgs splits a call body on top-level commas, honoring nested
// parentheses and single/double-quoted strings (with backslash escapes).
// Each argument is trimmed. NOTE(review): interior empty arguments are kept,
// but a trailing empty argument is dropped — confirm this asymmetry is
// intentional.
func splitCallArgs(input string) []string {
	args := make([]string, 0)
	current := strings.Builder{}
	depth := 0
	quote := byte(0)
	for i := 0; i < len(input); i++ {
		ch := input[i]
		if quote != 0 {
			current.WriteByte(ch)
			if ch == '\\' && i+1 < len(input) {
				// Copy the escaped character verbatim.
				i++
				current.WriteByte(input[i])
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '\'' || ch == '"' {
			quote = ch
			current.WriteByte(ch)
			continue
		}
		if ch == '(' {
			depth++
			current.WriteByte(ch)
			continue
		}
		if ch == ')' {
			if depth > 0 {
				depth--
			}
			current.WriteByte(ch)
			continue
		}
		if ch == ',' && depth == 0 {
			// Top-level comma terminates the current argument.
			args = append(args, strings.TrimSpace(current.String()))
			current.Reset()
			continue
		}
		current.WriteByte(ch)
	}
	if value := strings.TrimSpace(current.String()); value != "" {
		args = append(args, value)
	}
	return args
}

// skipSpaces advances index past spaces, tabs, and newlines.
func skipSpaces(input string, index int) int {
	for index < len(input) {
		switch input[index] {
		case ' ', '\t', '\n', '\r':
			index++
		default:
			return index
		}
	}
	return index
}

// skipInlineSpaces advances index past spaces and tabs only (stops at newlines).
func skipInlineSpaces(input string, index int) int {
	for index < len(input) {
		switch input[index] {
		case ' ', '\t':
			index++
		default:
			return index
		}
	}
	return index
}

// isIdentifierStart reports whether ch may begin an identifier (letter or '_').
func isIdentifierStart(ch byte) bool {
	return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
}

// isIdentifierPart reports whether ch may continue an identifier (adds digits).
func isIdentifierPart(ch byte) bool {
	return isIdentifierStart(ch) || (ch >= '0' && ch <= '9')
}

// parseQuotedLiteral unwraps a single- or double-quoted literal, returning the
// inner text; fails when the value is too short, unquoted, or mismatched.
func parseQuotedLiteral(input string) (string, bool) {
	value := strings.TrimSpace(input)
	if len(value) < 2 {
		return "", false
	}
	quote := value[0]
	if quote != '\'' && quote != '"' {
		return "", false
	}
	if
value[len(value)-1] != quote { + return "", false + } + return value[1 : len(value)-1], true +} + +func hasWordFoldAt(input string, pos int, word string) bool { + if pos < 0 || pos+len(word) > len(input) { + return false + } + if !strings.EqualFold(input[pos:pos+len(word)], word) { + return false + } + next := pos + len(word) + if next >= len(input) { + return true + } + return !isIdentifierPart(input[next]) +} diff --git a/repository/shape/dql/preprocess/extract.go b/repository/shape/dql/preprocess/extract.go new file mode 100644 index 000000000..67b761a41 --- /dev/null +++ b/repository/shape/dql/preprocess/extract.go @@ -0,0 +1,121 @@ +package preprocess + +import ( + "strings" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func extractSQLAndContext(dql string) (string, *typectx.Context, *dqlshape.Directives, []*dqlshape.Diagnostic) { + ctx := &typectx.Context{} + directives := &dqlshape.Directives{} + if dql == "" { + return "", nil, nil, nil + } + mask := make([]bool, len(dql)) + var diagnostics []*dqlshape.Diagnostic + + blocks := extractSetDirectiveBlocks(dql) + for _, block := range blocks { + applyMask(mask, dql, block.start, block.end) + if block.kind != directiveSettings { + continue + } + diagnostics = append(diagnostics, parseSettingsDirectives(block.body, dql, block.start, directives)...) + } + + lines := strings.SplitAfter(dql, "\n") + if len(lines) == 0 { + lines = []string{dql} + } + + offset := 0 + for _, line := range lines { + trimmed := strings.TrimSpace(line) + lineStart := offset + lineEnd := offset + len(line) + if isTypeContextDirectiveLine(trimmed) { + diagnostics = append(diagnostics, parseTypeContextDirective(trimmed, dql, offsetOfFirstNonSpace(line, offset), ctx)...) 
+ applyMask(mask, dql, lineStart, lineEnd) + offset += len(line) + continue + } + if kind := lineDirectiveKind(trimmed); kind != directiveUnknown { + if !hasMasked(mask, lineStart, lineEnd) { + if kind != directiveSettings { + applyMask(mask, dql, lineStart, lineEnd) + offset += len(line) + continue + } + diagnostics = append(diagnostics, parseSettingsDirectives(trimmed, dql, offsetOfFirstNonSpace(line, offset), directives)...) + applyMask(mask, dql, lineStart, lineEnd) + } + offset += len(line) + continue + } + if isDirectiveLine(trimmed) { + applyMask(mask, dql, lineStart, lineEnd) + } + offset += len(line) + } + masked := []byte(dql) + for i := 0; i < len(masked); i++ { + if !mask[i] { + continue + } + if masked[i] == '\n' || masked[i] == '\r' { + continue + } + masked[i] = ' ' + } + return string(masked), ctx, directives, diagnostics +} + +func applyMask(mask []bool, text string, start, end int) { + if start < 0 { + start = 0 + } + if end > len(text) { + end = len(text) + } + if end <= start { + return + } + for i := start; i < end; i++ { + if text[i] == '\n' || text[i] == '\r' { + continue + } + mask[i] = true + } +} + +func hasMasked(mask []bool, start, end int) bool { + if start < 0 { + start = 0 + } + if end > len(mask) { + end = len(mask) + } + if end <= start { + return false + } + for i := start; i < end; i++ { + if mask[i] { + return true + } + } + return false +} + +func offsetOfFirstNonSpace(line string, base int) int { + for i := 0; i < len(line); i++ { + switch line[i] { + case ' ', '\t', '\r', '\n': + continue + default: + return base + i + } + } + return base +} diff --git a/repository/shape/dql/preprocess/legacy_import.go b/repository/shape/dql/preprocess/legacy_import.go new file mode 100644 index 000000000..5218d0dee --- /dev/null +++ b/repository/shape/dql/preprocess/legacy_import.go @@ -0,0 +1,278 @@ +package preprocess + +import ( + "path" + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape 
"github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +type legacyImportRange struct { + start int + end int +} + +type legacyImportBlockSpec struct { + start int + end int + bodyStart int + bodyEnd int +} + +func extractLegacyTypeImports(dql string) ([]typectx.Import, []legacyImportRange, []*dqlshape.Diagnostic) { + if strings.TrimSpace(dql) == "" { + return nil, nil, nil + } + var ( + imports []typectx.Import + ranges []legacyImportRange + diags []*dqlshape.Diagnostic + ) + inBlock := make([]bool, len(dql)) + + blocks := findLegacyImportBlocks(dql) + for _, block := range blocks { + start, end := block.start, block.end + for i := start; i < end && i < len(inBlock); i++ { + inBlock[i] = true + } + ranges = append(ranges, legacyImportRange{start: start, end: end}) + blockBody := dql[block.bodyStart:block.bodyEnd] + items := parseLegacyImportItems(blockBody, block.bodyStart) + if len(items) == 0 { + diags = append(diags, directiveDiagnostic( + dqldiag.CodeDirImport, + "invalid legacy import declaration", + `expected: import "pkg/path.Type" or import ("pkg/path.Type" alias "x")`, + dql, + start, + )) + continue + } + for _, item := range items { + aImport, ok := parseLegacyImportSpec(item.spec, item.alias) + if !ok { + diags = append(diags, directiveDiagnostic( + dqldiag.CodeDirImport, + "invalid legacy import declaration", + `expected import target with type suffix: "pkg/path.Type"`, + dql, + item.offset, + )) + continue + } + imports = append(imports, aImport) + } + } + + offset := 0 + for _, line := range strings.SplitAfter(dql, "\n") { + start := offset + end := start + len(line) + if start >= len(inBlock) || inBlock[start] { + offset = end + continue + } + spec, alias, ok := parseLegacyImportLine(line) + if !ok { + offset = end + continue + } + aImport, ok := parseLegacyImportSpec(spec, alias) + if !ok { + diags = append(diags, directiveDiagnostic( + dqldiag.CodeDirImport, + "invalid legacy import 
declaration", + `expected import target with type suffix: "pkg/path.Type"`, + dql, + start, + )) + offset = end + continue + } + imports = append(imports, aImport) + ranges = append(ranges, legacyImportRange{start: start, end: end}) + offset = end + } + + return uniqueTypeImports(imports), ranges, diags +} + +func findLegacyImportBlocks(dql string) []legacyImportBlockSpec { + var result []legacyImportBlockSpec + for lineStart := 0; lineStart < len(dql); { + lineEnd := lineStart + for lineEnd < len(dql) && dql[lineEnd] != '\n' { + lineEnd++ + } + + pos := skipInlineSpaces(dql, lineStart) + if hasWordFoldAt(dql, pos, "import") { + pos = skipSpaces(dql, pos+len("import")) + if pos < len(dql) && dql[pos] == '(' { + body, end, ok := readCallBody(dql, pos) + if ok { + result = append(result, legacyImportBlockSpec{ + start: lineStart, + end: end + 1, + bodyStart: pos + 1, + bodyEnd: pos + 1 + len(body), + }) + lineStart = end + 1 + continue + } + } + } + + if lineEnd < len(dql) { + lineStart = lineEnd + 1 + } else { + break + } + } + return result +} + +type legacyImportItem struct { + spec string + alias string + offset int +} + +func parseLegacyImportItems(input string, base int) []legacyImportItem { + var result []legacyImportItem + for i := 0; i < len(input); { + i = skipLegacyImportSeparators(input, i) + if i >= len(input) { + break + } + start := i + spec, end, ok := readQuotedAt(input, i) + if !ok { + i++ + continue + } + i = skipSpaces(input, end) + alias := "" + if hasWordFoldAt(input, i, "alias") { + i = skipSpaces(input, i+len("alias")) + aliasValue, aliasEnd, ok := readQuotedAt(input, i) + if !ok { + i = end + continue + } + alias = aliasValue + i = aliasEnd + } + result = append(result, legacyImportItem{ + spec: strings.TrimSpace(spec), + alias: strings.TrimSpace(alias), + offset: base + start, + }) + } + return result +} + +func parseLegacyImportLine(line string) (spec, alias string, ok bool) { + input := strings.TrimSpace(line) + if input == "" || 
!hasWordFoldAt(input, 0, "import") { + return "", "", false + } + index := skipSpaces(input, len("import")) + specValue, end, ok := readQuotedAt(input, index) + if !ok { + return "", "", false + } + index = skipSpaces(input, end) + aliasValue := "" + if hasWordFoldAt(input, index, "alias") { + index = skipSpaces(input, index+len("alias")) + value, aliasEnd, ok := readQuotedAt(input, index) + if !ok { + return "", "", false + } + aliasValue = value + index = skipSpaces(input, aliasEnd) + } + if index != len(input) { + return "", "", false + } + return strings.TrimSpace(specValue), strings.TrimSpace(aliasValue), true +} + +func readQuotedAt(input string, index int) (string, int, bool) { + if index < 0 || index >= len(input) { + return "", index, false + } + quote := input[index] + if quote != '\'' && quote != '"' { + return "", index, false + } + for i := index + 1; i < len(input); i++ { + if input[i] == '\\' && i+1 < len(input) { + i++ + continue + } + if input[i] == quote { + return input[index+1 : i], i + 1, true + } + } + return "", index, false +} + +func skipLegacyImportSeparators(input string, index int) int { + for index < len(input) { + switch input[index] { + case ' ', '\t', '\n', '\r', ',', ';': + index++ + default: + return index + } + } + return index +} + +func parseLegacyImportSpec(spec, alias string) (typectx.Import, bool) { + spec = strings.TrimSpace(spec) + if spec == "" { + return typectx.Import{}, false + } + index := strings.LastIndex(spec, ".") + if index <= 0 || index >= len(spec)-1 { + return typectx.Import{}, false + } + pkg := strings.TrimSpace(spec[:index]) + typeName := strings.TrimSpace(spec[index+1:]) + if pkg == "" || typeName == "" { + return typectx.Import{}, false + } + alias = strings.TrimSpace(alias) + if alias == "" { + alias = path.Base(pkg) + } + return typectx.Import{Alias: alias, Package: pkg}, true +} + +func uniqueTypeImports(input []typectx.Import) []typectx.Import { + if len(input) == 0 { + return nil + } + seen := 
map[string]bool{} + result := make([]typectx.Import, 0, len(input)) + for _, item := range input { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + alias := strings.TrimSpace(item.Alias) + key := strings.ToLower(alias + "|" + pkg) + if seen[key] { + continue + } + seen[key] = true + result = append(result, typectx.Import{Alias: alias, Package: pkg}) + } + return result +} diff --git a/repository/shape/dql/preprocess/mapper.go b/repository/shape/dql/preprocess/mapper.go new file mode 100644 index 000000000..bb3f14f3a --- /dev/null +++ b/repository/shape/dql/preprocess/mapper.go @@ -0,0 +1,181 @@ +package preprocess + +import ( + "sort" + "unicode/utf8" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/velty" +) + +type Mapper struct { + trimPrefix int + segments []mapSegment + original string +} + +type mapSegment struct { + newStart int + newEnd int + origBase int + linear bool +} + +func (m *Mapper) MapOffset(offset int) int { + if m == nil { + if offset < 0 { + return 0 + } + return offset + } + mapped := offset + m.trimPrefix + if mapped < 0 { + mapped = 0 + } + for _, seg := range m.segments { + if mapped < seg.newStart || mapped > seg.newEnd { + continue + } + if seg.linear { + delta := mapped - seg.newStart + if delta < 0 { + delta = 0 + } + return seg.origBase + delta + } + return seg.origBase + } + if len(m.segments) == 0 { + return mapped + } + last := m.segments[len(m.segments)-1] + if last.linear { + return last.origBase + (last.newEnd - last.newStart) + } + return last.origBase +} + +func (m *Mapper) Position(offset int) dqlshape.Position { + return positionAt(m.original, m.MapOffset(offset)) +} + +func (m *Mapper) Remap(diags []*dqlshape.Diagnostic) { + if m == nil || len(diags) == 0 { + return + } + for _, diag := range diags { + if diag == nil { + continue + } + start := m.Position(diag.Span.Start.Offset) + end := m.Position(diag.Span.End.Offset) + diag.Span.Start = start + diag.Span.End = 
end + } +} + +func newMapper(srcLen int, patches []velty.Patch, trimPrefix int, original string) *Mapper { + if trimPrefix < 0 { + trimPrefix = 0 + } + ps := append([]velty.Patch{}, patches...) + sort.Slice(ps, func(i, j int) bool { return ps[i].Span.Start < ps[j].Span.Start }) + segments := make([]mapSegment, 0, len(ps)*2+1) + oldPos := 0 + newPos := 0 + for _, p := range ps { + start := p.Span.Start + end := p.Span.End + 1 + if start < oldPos || start < 0 || end < start || end > srcLen { + continue + } + if start > oldPos { + blockLen := start - oldPos + segments = append(segments, mapSegment{ + newStart: newPos, + newEnd: newPos + blockLen, + origBase: oldPos, + linear: true, + }) + oldPos = start + newPos += blockLen + } + replLen := len(p.Replacement) + if replLen > 0 { + segments = append(segments, mapSegment{ + newStart: newPos, + newEnd: newPos + replLen, + origBase: start, + linear: false, + }) + newPos += replLen + } + oldPos = end + } + if oldPos < srcLen { + blockLen := srcLen - oldPos + segments = append(segments, mapSegment{ + newStart: newPos, + newEnd: newPos + blockLen, + origBase: oldPos, + linear: true, + }) + } + return &Mapper{trimPrefix: trimPrefix, segments: segments, original: original} +} + +func pointSpan(text string, offset int) dqlshape.Span { + start := positionAt(text, offset) + end := positionAt(text, nextOffset(text, offset)) + return dqlshape.Span{Start: start, End: end} +} + +// PointSpan returns a single-point span at offset with rune-aware line/char. 
+func PointSpan(text string, offset int) dqlshape.Span { + return pointSpan(text, offset) +} + +func nextOffset(text string, offset int) int { + if offset < 0 { + return 0 + } + if offset >= len(text) { + return len(text) + } + _, width := utf8.DecodeRuneInString(text[offset:]) + if width <= 0 { + return offset + 1 + } + return offset + width +} + +func positionAt(text string, offset int) dqlshape.Position { + if offset < 0 { + offset = 0 + } + if offset > len(text) { + offset = len(text) + } + line := 1 + char := 1 + index := 0 + for index < offset { + r, width := utf8.DecodeRuneInString(text[index:]) + if width <= 0 { + break + } + index += width + if r == '\n' { + line++ + char = 1 + } else { + char++ + } + } + return dqlshape.Position{Offset: offset, Line: line, Char: char} +} + +// PositionAt returns rune-aware position for byte offset. +func PositionAt(text string, offset int) dqlshape.Position { + return positionAt(text, offset) +} diff --git a/repository/shape/dql/preprocess/preprocess.go b/repository/shape/dql/preprocess/preprocess.go new file mode 100644 index 000000000..9579dfb0c --- /dev/null +++ b/repository/shape/dql/preprocess/preprocess.go @@ -0,0 +1,156 @@ +package preprocess + +import ( + "strings" + + dqlopt "github.com/viant/datly/repository/shape/dql/optimize" + dqlsanitize "github.com/viant/datly/repository/shape/dql/sanitize" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +type Result struct { + Original string + DirectSQL string + Optimized string + SQL string + TypeCtx *typectx.Context + Directives *dqlshape.Directives + Mapper *Mapper + Diagnostics []*dqlshape.Diagnostic +} + +// Extract parses directives and returns SQL with directive lines masked to preserve offsets. 
+func Extract(dql string) (string, *typectx.Context, *dqlshape.Directives, []*dqlshape.Diagnostic) { + sql, ctx, directives, diags := extractSQLAndContext(dql) + return sql, normalizeTypeContext(ctx), normalizeDirectives(directives), diags +} + +func Prepare(dql string) *Result { + ret := &Result{Original: dql} + sql, typeCtx, directives, dirDiags := Extract(dql) + ret.DirectSQL = stripDecorators(sql) + ret.TypeCtx = typeCtx + ret.Directives = directives + ret.Diagnostics = append(ret.Diagnostics, dirDiags...) + if strings.TrimSpace(ret.DirectSQL) == "" { + return ret + } + optimized, optDiags := dqlopt.Rewrite(ret.DirectSQL) + ret.Diagnostics = append(ret.Diagnostics, optDiags...) + ret.Optimized = optimized + sanitized := dqlsanitize.Rewrite(optimized, dqlsanitize.Options{ + Declared: dqlsanitize.Declared(optimized), + }) + ret.SQL = sanitized.SQL + ret.Mapper = newMapper(len(optimized), sanitized.Patches, sanitized.TrimPrefix, dql) + return ret +} + +func stripDecorators(sql string) string { + if strings.TrimSpace(sql) == "" { + return sql + } + lines := strings.Split(sql, "\n") + filtered := make([]string, 0, len(lines)) + for _, line := range lines { + if isStandaloneDecoratorLine(line) { + continue + } + filtered = append(filtered, line) + } + return cleanupLineCommaArtifacts(filtered) +} + +func isStandaloneDecoratorLine(line string) bool { + trimmed := strings.TrimSpace(strings.TrimSuffix(line, ",")) + if trimmed == "" { + return false + } + open := strings.Index(trimmed, "(") + close := strings.LastIndex(trimmed, ")") + if open <= 0 || close <= open { + return false + } + name := strings.ToLower(strings.TrimSpace(trimmed[:open])) + switch name { + case "use_connector", "allow_nulls", "allownulls", "tag", "cast", "required", "cardinality", "set_limit": + return true + default: + return false + } +} + +func cleanupLineCommaArtifacts(lines []string) string { + if len(lines) == 0 { + return "" + } + result := make([]string, 0, len(lines)) + for _, line := 
range lines { + trimmed := strings.TrimSpace(line) + if len(result) > 0 && strings.HasPrefix(strings.ToLower(trimmed), "from ") { + prev := strings.TrimRight(result[len(result)-1], " \t") + prev = strings.TrimSuffix(prev, ",") + result[len(result)-1] = prev + } + result = append(result, line) + } + return strings.Join(result, "\n") +} + +func normalizeTypeContext(ctx *typectx.Context) *typectx.Context { + if ctx == nil { + return nil + } + if ctx.DefaultPackage == "" && len(ctx.Imports) == 0 { + return nil + } + return ctx +} + +func normalizeDirectives(input *dqlshape.Directives) *dqlshape.Directives { + if input == nil { + return nil + } + ret := &dqlshape.Directives{ + Meta: strings.TrimSpace(input.Meta), + DefaultConnector: strings.TrimSpace(input.DefaultConnector), + JSONMarshalType: strings.TrimSpace(input.JSONMarshalType), + JSONUnmarshalType: strings.TrimSpace(input.JSONUnmarshalType), + XMLUnmarshalType: strings.TrimSpace(input.XMLUnmarshalType), + Format: strings.TrimSpace(input.Format), + DateFormat: strings.TrimSpace(input.DateFormat), + CaseFormat: strings.TrimSpace(input.CaseFormat), + } + if input.Cache != nil { + ret.Cache = &dqlshape.CacheDirective{ + Enabled: input.Cache.Enabled, + TTL: strings.TrimSpace(input.Cache.TTL), + } + } + if input.MCP != nil { + ret.MCP = &dqlshape.MCPDirective{ + Name: strings.TrimSpace(input.MCP.Name), + Description: strings.TrimSpace(input.MCP.Description), + DescriptionPath: strings.TrimSpace(input.MCP.DescriptionPath), + } + } + if input.Route != nil { + normalizedMethods := make([]string, 0, len(input.Route.Methods)) + for _, method := range input.Route.Methods { + if method = strings.TrimSpace(method); method != "" { + normalizedMethods = append(normalizedMethods, method) + } + } + ret.Route = &dqlshape.RouteDirective{ + URI: strings.TrimSpace(input.Route.URI), + Methods: normalizedMethods, + } + } + if ret.Meta == "" && ret.DefaultConnector == "" && ret.Cache == nil && ret.MCP == nil && ret.Route == nil && + 
ret.JSONMarshalType == "" && ret.JSONUnmarshalType == "" && ret.XMLUnmarshalType == "" && ret.Format == "" && + ret.DateFormat == "" && ret.CaseFormat == "" { + return nil + } + return ret +} diff --git a/repository/shape/dql/preprocess/preprocess_test.go b/repository/shape/dql/preprocess/preprocess_test.go new file mode 100644 index 000000000..337b9ca41 --- /dev/null +++ b/repository/shape/dql/preprocess/preprocess_test.go @@ -0,0 +1,206 @@ +package preprocess + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" +) + +func TestPrepare_TypeContext(t *testing.T) { + dql := "#package('a/b')\n#import('x','github.com/acme/x')\nSELECT id FROM t" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotNil(t, pre.TypeCtx) + assert.Equal(t, "a/b", pre.TypeCtx.DefaultPackage) + require.Len(t, pre.TypeCtx.Imports, 1) + assert.Equal(t, "x", pre.TypeCtx.Imports[0].Alias) +} + +func TestPrepare_InvalidDirectiveDiagnostic(t *testing.T) { + dql := "SELECT 1\n#import('x')" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirImport, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) + assert.Equal(t, 1, pre.Diagnostics[0].Span.Start.Char) +} + +func TestMapper_MapOffset_WithSanitizeExpansion(t *testing.T) { + dql := "SELECT id FROM ORDERS t WHERE t.id = $Id AND (" + pre := Prepare(dql) + require.NotNil(t, pre.Mapper) + // Syntax error location after sanitize rewrite should map back to original source. 
+ offset := len(pre.SQL) - 1 + pos := pre.Mapper.Position(offset) + assert.Equal(t, 1, pos.Line) + assert.Equal(t, 46, pos.Char) +} + +func TestPrepare_StripsReadDecorators(t *testing.T) { + dql := `SELECT t.*, +use_connector(t, system), +allow_nulls(t) +FROM t` + pre := Prepare(dql) + require.NotNil(t, pre) + assert.NotContains(t, pre.DirectSQL, "use_connector") + assert.NotContains(t, pre.DirectSQL, "allow_nulls") + assert.Contains(t, pre.DirectSQL, "SELECT t.*") + assert.Contains(t, pre.DirectSQL, "FROM t") + assert.NotContains(t, pre.DirectSQL, ",\nFROM") +} + +func TestPrepare_MultilineSetDirective_TypeContext(t *testing.T) { + dql := "#package('a/b')\n#import('x','github.com/acme/x')\nSELECT id FROM t" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotNil(t, pre.TypeCtx) + assert.Equal(t, "a/b", pre.TypeCtx.DefaultPackage) + require.Len(t, pre.TypeCtx.Imports, 1) + assert.Equal(t, "x", pre.TypeCtx.Imports[0].Alias) + assert.Equal(t, "github.com/acme/x", pre.TypeCtx.Imports[0].Package) + assert.Contains(t, pre.DirectSQL, "SELECT id FROM t") +} + +func TestPrepare_InvalidMultilineImportDiagnostic(t *testing.T) { + dql := "SELECT 1\n#import(\n'x'\n)" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirImport, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) + assert.GreaterOrEqual(t, pre.Diagnostics[0].Span.Start.Char, 1) +} + +func TestPrepare_SpecialDirectives(t *testing.T) { + dql := "#settings($_ = $meta('docs/orders.md'))\n" + + "#setting($_ = $connector('analytics'))\n" + + "#settings($_ = $cache(true, '5m'))\n" + + "#settings($_ = $mcp('orders.search', 'Search orders', 'docs/mcp/orders.md'))\n" + + "#settings($_ = $marshal('application/json','pkg.OrderJSON'))\n" + + "#settings($_ = $unmarshal('application/json','pkg.OrderIn'))\n" + + "#settings($_ = $unmarshal('application/xml','pkg.OrderXMLIn'))\n" + + "#settings($_ = $format('tabular_json'))\n" 
+ + "#settings($_ = $date_format('2006-01-02'))\n" + + "#settings($_ = $case_format('lc'))\n" + + "SELECT id FROM ORDERS o" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotNil(t, pre.Directives) + assert.Equal(t, "docs/orders.md", pre.Directives.Meta) + assert.Equal(t, "analytics", pre.Directives.DefaultConnector) + require.NotNil(t, pre.Directives.Cache) + assert.True(t, pre.Directives.Cache.Enabled) + assert.Equal(t, "5m", pre.Directives.Cache.TTL) + require.NotNil(t, pre.Directives.MCP) + assert.Equal(t, "orders.search", pre.Directives.MCP.Name) + assert.Equal(t, "Search orders", pre.Directives.MCP.Description) + assert.Equal(t, "docs/mcp/orders.md", pre.Directives.MCP.DescriptionPath) + assert.Equal(t, "pkg.OrderJSON", pre.Directives.JSONMarshalType) + assert.Equal(t, "pkg.OrderIn", pre.Directives.JSONUnmarshalType) + assert.Equal(t, "pkg.OrderXMLIn", pre.Directives.XMLUnmarshalType) + assert.Equal(t, "tabular", pre.Directives.Format) + assert.Equal(t, "2006-01-02", pre.Directives.DateFormat) + assert.Equal(t, "lc", pre.Directives.CaseFormat) +} + +func TestPrepare_InvalidSpecialDirectiveDiagnostic(t *testing.T) { + dql := "SELECT 1\n#settings($_ = $mcp())" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirMCP, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) +} + +func TestPrepare_InvalidConnectorDirectiveDiagnostic(t *testing.T) { + dql := "SELECT 1\n#settings($_ = $connector())" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirConnector, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) +} + +func TestPrepare_RouteDirective(t *testing.T) { + dql := "SELECT 1\n#settings($_ = $route('/v1/api/orders', 'GET', 'POST', 'PATCH'))" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotNil(t, pre.Directives) + require.NotNil(t, pre.Directives.Route) 
+ assert.Equal(t, "/v1/api/orders", pre.Directives.Route.URI) + assert.Equal(t, []string{"GET", "POST", "PATCH"}, pre.Directives.Route.Methods) +} + +func TestPrepare_InvalidRouteDirectiveDiagnostic(t *testing.T) { + dql := "SELECT 1\n#settings($_ = $route('/v1/api/orders', 'GOT'))" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirRoute, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) +} + +func TestPrepare_InvalidCaseFormatDirectiveDiagnostic(t *testing.T) { + dql := "SELECT 1\n#settings($_ = $case_format('unknown'))" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirCaseFormat, pre.Diagnostics[0].Code) +} + +func TestPrepare_DefineDirective_DoesNotDriveSettingsExtraction(t *testing.T) { + dql := "#define($_ = $package('a/b'))\nSELECT 1" + pre := Prepare(dql) + require.NotNil(t, pre) + assert.Nil(t, pre.TypeCtx) +} + +func TestPrepare_PackageImportInSettings_UnsupportedDiagnostic(t *testing.T) { + dql := "#settings($_ = $package('x'))\nSELECT 1" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirUnsupported, pre.Diagnostics[0].Code) + assert.Equal(t, 1, pre.Diagnostics[0].Span.Start.Line) +} + +func TestPrepare_TypeContext_CaseInsensitive(t *testing.T) { + dql := "#Package('a/b')\n#Import('x','github.com/acme/x')\nSELECT id FROM t" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotNil(t, pre.TypeCtx) + assert.Equal(t, "a/b", pre.TypeCtx.DefaultPackage) + require.Len(t, pre.TypeCtx.Imports, 1) + assert.Equal(t, "x", pre.TypeCtx.Imports[0].Alias) + assert.Equal(t, "github.com/acme/x", pre.TypeCtx.Imports[0].Package) +} + +func TestExtractLegacyTypeImports_BlockAndLine(t *testing.T) { + dql := "import (\n" + + " \"github.com/acme/a.TypeA\"\n" + + " \"github.com/acme/b.TypeB\" alias \"b\"\n" + + ")\n" + + "import 
\"github.com/acme/c.TypeC\"\n" + + imports, ranges, diags := extractLegacyTypeImports(dql) + require.Empty(t, diags) + require.Len(t, ranges, 2) + require.Len(t, imports, 3) + assert.Equal(t, "a", imports[0].Alias) + assert.Equal(t, "github.com/acme/a", imports[0].Package) + assert.Equal(t, "b", imports[1].Alias) + assert.Equal(t, "github.com/acme/b", imports[1].Package) + assert.Equal(t, "c", imports[2].Alias) + assert.Equal(t, "github.com/acme/c", imports[2].Package) +} + +func TestExtractLegacyTypeImports_InvalidBlockDiagnostic(t *testing.T) { + dql := "import (\n alias \"oops\"\n)\nSELECT 1" + _, _, diags := extractLegacyTypeImports(dql) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeDirImport, diags[0].Code) +} diff --git a/repository/shape/dql/preprocess/scanner.go b/repository/shape/dql/preprocess/scanner.go new file mode 100644 index 000000000..3b7513917 --- /dev/null +++ b/repository/shape/dql/preprocess/scanner.go @@ -0,0 +1,137 @@ +package preprocess + +import ( + "strings" + + "github.com/viant/parsly" + "github.com/viant/parsly/matcher" +) + +var ( + ppWhitespaceToken = 1 + ppExprGroupToken = 2 + + ppWhitespaceMatcher = parsly.NewToken(ppWhitespaceToken, "Whitespace", matcher.NewWhiteSpace()) + ppExprGroupMatcher = parsly.NewToken(ppExprGroupToken, "( ... 
)", matcher.NewBlock('(', ')', '\\')) +) + +type setDirectiveBlock struct { + start int + end int + body string + kind directiveKind +} + +type directiveKind int + +const ( + directiveUnknown directiveKind = iota + directiveSet + directiveDefine + directiveSettings +) + +func isDirectiveLine(line string) bool { + if line == "" { + return false + } + if isTypeContextDirectiveLine(line) { + return true + } + if isSetLine(line) { + return true + } + if strings.HasPrefix(line, "#if(") || strings.HasPrefix(line, "#elseif(") || strings.HasPrefix(line, "#else") || strings.HasPrefix(line, "#end") { + return true + } + return false +} + +func isSetLine(line string) bool { + if line == "" { + return false + } + return lineDirectiveKind(line) != directiveUnknown +} + +func extractSetDirectiveBlocks(dql string) []setDirectiveBlock { + cursor := parsly.NewCursor("", []byte(dql), 0) + var result []setDirectiveBlock + for cursor.Pos < cursor.InputSize { + start := cursor.Pos + kind, keywordLen, ok := matchDirectiveAt(dql, start) + if !ok { + cursor.Pos++ + continue + } + cursor.Pos += keywordLen + group := cursor.MatchAfterOptional(ppWhitespaceMatcher, ppExprGroupMatcher) + if group.Code != ppExprGroupToken { + cursor.Pos = start + 1 + continue + } + groupText := group.Text(cursor) + if len(groupText) < 2 { + continue + } + end := cursor.Pos + result = append(result, setDirectiveBlock{ + start: start, + end: end, + body: groupText[1 : len(groupText)-1], + kind: kind, + }) + } + return result +} + +func lineDirectiveKind(line string) directiveKind { + if line == "" { + return directiveUnknown + } + switch { + case strings.HasPrefix(line, "#settings("), strings.HasPrefix(line, "#settings ("): + return directiveSettings + case strings.HasPrefix(line, "#setting("), strings.HasPrefix(line, "#setting ("): + return directiveSettings + case strings.HasPrefix(line, "#define("), strings.HasPrefix(line, "#define ("): + return directiveDefine + case strings.HasPrefix(line, "#set("), 
strings.HasPrefix(line, "#set ("): + return directiveSet + default: + return directiveUnknown + } +} + +func matchDirectiveAt(dql string, pos int) (directiveKind, int, bool) { + if pos < 0 || pos >= len(dql) || dql[pos] != '#' { + return directiveUnknown, 0, false + } + remaining := dql[pos:] + switch { + case hasDirectivePrefix(remaining, "#settings"): + return directiveSettings, len("#settings"), true + case hasDirectivePrefix(remaining, "#setting"): + return directiveSettings, len("#setting"), true + case hasDirectivePrefix(remaining, "#define"): + return directiveDefine, len("#define"), true + case hasDirectivePrefix(remaining, "#set"): + return directiveSet, len("#set"), true + default: + return directiveUnknown, 0, false + } +} + +func hasDirectivePrefix(input string, directive string) bool { + if len(input) < len(directive) { + return false + } + if !strings.EqualFold(input[:len(directive)], directive) { + return false + } + if len(input) == len(directive) { + return true + } + next := input[len(directive)] + return next == '(' || next == ' ' || next == '\t' || next == '\r' || next == '\n' +} diff --git a/repository/shape/dql/preprocess/settings_directives.go b/repository/shape/dql/preprocess/settings_directives.go new file mode 100644 index 000000000..3d7793c8f --- /dev/null +++ b/repository/shape/dql/preprocess/settings_directives.go @@ -0,0 +1,435 @@ +package preprocess + +import ( + "net/http" + "strings" + + "github.com/viant/datly/repository/content" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/tagly/format/text" +) + +var ( + metaDirectiveName = map[string]bool{"meta": true} + connectorDirectiveName = map[string]bool{"connector": true} + cacheDirectiveName = map[string]bool{"cache": true} + mcpDirectiveName = map[string]bool{"mcp": true} + routeDirectiveName = map[string]bool{"route": true} + marshalDirectiveName = map[string]bool{"marshal": true} + 
unmarshalDirectiveName = map[string]bool{"unmarshal": true} + formatDirectiveName = map[string]bool{"format": true} + dateFormatDirectiveName = map[string]bool{"date_format": true} + caseFormatDirectiveName = map[string]bool{"case_format": true} +) + +func parseSettingsDirectives(input, fullDQL string, diagnosticOffset int, directives *dqlshape.Directives) []*dqlshape.Diagnostic { + if strings.TrimSpace(input) == "" { + return nil + } + var diagnostics []*dqlshape.Diagnostic + lower := strings.ToLower(input) + if strings.Contains(lower, "$package") || strings.Contains(lower, "$import") { + diagnostics = append(diagnostics, directiveDiagnostic( + dqldiag.CodeDirUnsupported, + "type-context directives are not allowed in #settings", + "use #package('module/path') and #import('alias','github.com/acme/pkg')", + fullDQL, + diagnosticOffset, + )) + } + if strings.Contains(lower, "$meta") { + values := parseMetaDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMeta, "invalid $meta directive", "expected: #settings($_ = $meta('relative/or/absolute/path'))", fullDQL, diagnosticOffset)) + } else { + directives.Meta = values[len(values)-1] + } + } + if strings.Contains(lower, "$connector") { + values := parseConnectorDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirConnector, "invalid $connector directive", "expected: #settings($_ = $connector('connector_name'))", fullDQL, diagnosticOffset)) + } else { + directives.DefaultConnector = values[len(values)-1] + } + } + if strings.Contains(lower, "$cache") { + values := parseCacheDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirCache, "invalid $cache directive", "expected: #settings($_ = $cache(true, '5m'))", fullDQL, diagnosticOffset)) + } else { + directives.Cache = values[len(values)-1] + } + } + if strings.Contains(lower, "$mcp") { + values 
:= parseMCPDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMCP, "invalid $mcp directive", "expected: #settings($_ = $mcp('tool.name','description','docs/path.md'))", fullDQL, diagnosticOffset)) + } else { + directives.MCP = values[len(values)-1] + } + } + if strings.Contains(lower, "$route") { + values := parseRouteDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirRoute, "invalid $route directive", "expected: #settings($_ = $route('/v1/api/path','GET','POST'))", fullDQL, diagnosticOffset)) + } else { + directives.Route = values[len(values)-1] + } + } + if strings.Contains(lower, "$marshal") { + values := parseMarshalDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMarshal, "invalid $marshal directive", "expected: #settings($_ = $marshal('application/json','pkg.Type'))", fullDQL, diagnosticOffset)) + } else { + directives.JSONMarshalType = values[len(values)-1] + } + } + if strings.Contains(lower, "$unmarshal") { + values := parseUnmarshalDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirUnmarshal, "invalid $unmarshal directive", "expected: #settings($_ = $unmarshal('application/json','pkg.Type'))", fullDQL, diagnosticOffset)) + } else { + last := values[len(values)-1] + if last.JSONType != "" { + directives.JSONUnmarshalType = last.JSONType + } + if last.XMLType != "" { + directives.XMLUnmarshalType = last.XMLType + } + } + } + if strings.Contains(lower, "$format") { + values := parseFormatDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirFormat, "invalid $format directive", "expected: #settings($_ = $format('tabular_json'))", fullDQL, diagnosticOffset)) + } else { + directives.Format = values[len(values)-1] + } + } + if strings.Contains(lower, 
"$date_format") { + values := parseDateFormatDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirDateFormat, "invalid $date_format directive", "expected: #settings($_ = $date_format('2006-01-02'))", fullDQL, diagnosticOffset)) + } else { + directives.DateFormat = values[len(values)-1] + } + } + if strings.Contains(lower, "$case_format") { + values := parseCaseFormatDirectives(input) + if len(values) == 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirCaseFormat, "invalid $case_format directive", "expected: #settings($_ = $case_format('lc'))", fullDQL, diagnosticOffset)) + } else { + directives.CaseFormat = values[len(values)-1] + } + } + return diagnostics +} + +func parseMetaDirectives(input string) []string { + calls := scanDollarCalls(input, metaDirectiveName) + result := make([]string, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 1 { + continue + } + value, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + if value = strings.TrimSpace(value); value != "" { + result = append(result, value) + } + } + return result +} + +func parseConnectorDirectives(input string) []string { + calls := scanDollarCalls(input, connectorDirectiveName) + result := make([]string, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 1 { + continue + } + value, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + if value = strings.TrimSpace(value); value != "" { + result = append(result, value) + } + } + return result +} + +func parseCacheDirectives(input string) []*dqlshape.CacheDirective { + calls := scanDollarCalls(input, cacheDirectiveName) + result := make([]*dqlshape.CacheDirective, 0, len(calls)) + for _, call := range calls { + if len(call.args) == 0 || len(call.args) > 2 { + continue + } + enabledRaw := strings.TrimSpace(call.args[0]) + var enabled bool + switch { + case strings.EqualFold(enabledRaw, "true"): + enabled = 
true + case strings.EqualFold(enabledRaw, "false"): + enabled = false + default: + continue + } + ttl := "" + if len(call.args) == 2 { + value, ok := parseQuotedLiteral(call.args[1]) + if !ok { + continue + } + ttl = strings.TrimSpace(value) + } + result = append(result, &dqlshape.CacheDirective{Enabled: enabled, TTL: ttl}) + } + return result +} + +func parseMCPDirectives(input string) []*dqlshape.MCPDirective { + calls := scanDollarCalls(input, mcpDirectiveName) + result := make([]*dqlshape.MCPDirective, 0, len(calls)) + for _, call := range calls { + if len(call.args) < 1 || len(call.args) > 3 { + continue + } + name, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + name = strings.TrimSpace(name) + if name == "" { + continue + } + description := "" + if len(call.args) > 1 { + value, ok := parseQuotedLiteral(call.args[1]) + if !ok { + continue + } + description = strings.TrimSpace(value) + } + descriptionPath := "" + if len(call.args) > 2 { + value, ok := parseQuotedLiteral(call.args[2]) + if !ok { + continue + } + descriptionPath = strings.TrimSpace(value) + } + result = append(result, &dqlshape.MCPDirective{ + Name: name, + Description: description, + DescriptionPath: descriptionPath, + }) + } + return result +} + +func parseRouteDirectives(input string) []*dqlshape.RouteDirective { + calls := scanDollarCalls(input, routeDirectiveName) + result := make([]*dqlshape.RouteDirective, 0, len(calls)) + for _, call := range calls { + if len(call.args) == 0 { + continue + } + uri, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + uri = strings.TrimSpace(uri) + if !strings.HasPrefix(uri, "/") { + continue + } + methodsRaw := make([]string, 0, len(call.args)-1) + for _, arg := range call.args[1:] { + method, ok := parseQuotedLiteral(arg) + if !ok { + methodsRaw = nil + break + } + methodsRaw = append(methodsRaw, method) + } + if methodsRaw == nil { + continue + } + methods, ok := normalizeHTTPMethods(methodsRaw) + if !ok { + continue + 
} + result = append(result, &dqlshape.RouteDirective{ + URI: uri, + Methods: methods, + }) + } + return result +} + +func normalizeHTTPMethods(input []string) ([]string, bool) { + if len(input) == 0 { + return nil, true + } + valid := map[string]bool{ + http.MethodGet: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodHead: true, + http.MethodOptions: true, + http.MethodTrace: true, + http.MethodConnect: true, + } + seen := map[string]bool{} + result := make([]string, 0, len(input)) + for _, item := range input { + method := strings.ToUpper(strings.TrimSpace(item)) + if method == "" { + return nil, false + } + if !valid[method] { + return nil, false + } + if seen[method] { + continue + } + seen[method] = true + result = append(result, method) + } + return result, true +} + +func parseMarshalDirectives(input string) []string { + calls := scanDollarCalls(input, marshalDirectiveName) + result := make([]string, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 2 { + continue + } + mimeType, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + mimeType = strings.ToLower(strings.TrimSpace(mimeType)) + if mimeType != content.JSONContentType { + continue + } + typeName, ok := parseQuotedLiteral(call.args[1]) + if !ok { + continue + } + if typeName = strings.TrimSpace(typeName); typeName != "" { + result = append(result, typeName) + } + } + return result +} + +type unmarshalDirectiveValue struct { + JSONType string + XMLType string +} + +func parseUnmarshalDirectives(input string) []unmarshalDirectiveValue { + calls := scanDollarCalls(input, unmarshalDirectiveName) + result := make([]unmarshalDirectiveValue, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 2 { + continue + } + mimeType, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + typeName, ok := parseQuotedLiteral(call.args[1]) + if !ok { + continue + } + mimeType = 
strings.ToLower(strings.TrimSpace(mimeType)) + typeName = strings.TrimSpace(typeName) + if typeName == "" { + continue + } + value := unmarshalDirectiveValue{} + switch mimeType { + case content.JSONContentType: + value.JSONType = typeName + case content.XMLContentType: + value.XMLType = typeName + default: + continue + } + result = append(result, value) + } + return result +} + +func parseFormatDirectives(input string) []string { + calls := scanDollarCalls(input, formatDirectiveName) + result := make([]string, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 1 { + continue + } + raw, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + raw = strings.ToLower(strings.TrimSpace(raw)) + switch raw { + case "tabular_json": + result = append(result, content.JSONDataFormatTabular) + case content.JSONFormat, content.XMLFormat, content.CSVFormat, content.JSONDataFormatTabular: + result = append(result, raw) + } + } + return result +} + +func parseDateFormatDirectives(input string) []string { + calls := scanDollarCalls(input, dateFormatDirectiveName) + result := make([]string, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 1 { + continue + } + value, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + if value = strings.TrimSpace(value); value != "" { + result = append(result, value) + } + } + return result +} + +func parseCaseFormatDirectives(input string) []string { + calls := scanDollarCalls(input, caseFormatDirectiveName) + result := make([]string, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 1 { + continue + } + value, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + value = strings.TrimSpace(value) + if value == "" { + continue + } + if !text.NewCaseFormat(value).IsDefined() { + continue + } + result = append(result, value) + } + return result +} diff --git a/repository/shape/dql/preprocess/typectx_directives.go 
b/repository/shape/dql/preprocess/typectx_directives.go new file mode 100644 index 000000000..0307ad07c --- /dev/null +++ b/repository/shape/dql/preprocess/typectx_directives.go @@ -0,0 +1,121 @@ +package preprocess + +import ( + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func parseTypeContextDirective(line, fullDQL string, offset int, ctx *typectx.Context) []*dqlshape.Diagnostic { + var diagnostics []*dqlshape.Diagnostic + if pkg, ok := parsePackageLineDirective(line); ok { + ctx.DefaultPackage = pkg + return nil + } + if alias, pkg, ok := parseImportLineDirective(line); ok { + ctx.Imports = append(ctx.Imports, typectx.Import{Alias: alias, Package: pkg}) + return nil + } + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(line)), "#package") { + diagnostics = append(diagnostics, directiveDiagnostic( + dqldiag.CodeDirPackage, + "invalid #package directive", + "expected: #package('module/path')", + fullDQL, + offset, + )) + return diagnostics + } + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(line)), "#import") { + diagnostics = append(diagnostics, directiveDiagnostic( + dqldiag.CodeDirImport, + "invalid #import directive", + "expected: #import('alias','github.com/acme/pkg')", + fullDQL, + offset, + )) + } + return diagnostics +} + +func parsePackageLineDirective(line string) (string, bool) { + args, ok := parseExactHashDirectiveCall(line, "package") + if !ok || len(args) != 1 { + return "", false + } + value, ok := parseQuotedLiteral(args[0]) + if !ok { + return "", false + } + value = strings.TrimSpace(value) + if value == "" { + return "", false + } + return value, true +} + +func parseImportLineDirective(line string) (string, string, bool) { + args, ok := parseExactHashDirectiveCall(line, "import") + if !ok || len(args) != 2 { + return "", "", false + } + alias, ok := parseQuotedLiteral(args[0]) + 
if !ok { + return "", "", false + } + pkg, ok := parseQuotedLiteral(args[1]) + if !ok { + return "", "", false + } + alias = strings.TrimSpace(alias) + pkg = strings.TrimSpace(pkg) + if alias == "" || pkg == "" { + return "", "", false + } + return alias, pkg, true +} + +func isTypeContextDirectiveLine(line string) bool { + line = strings.ToLower(strings.TrimSpace(line)) + switch { + case strings.HasPrefix(line, "#package("), strings.HasPrefix(line, "#package ("): + return true + case strings.HasPrefix(line, "#import("), strings.HasPrefix(line, "#import ("): + return true + default: + return false + } +} + +func parseExactHashDirectiveCall(line, directive string) ([]string, bool) { + input := strings.TrimSpace(line) + if input == "" || input[0] != '#' { + return nil, false + } + index := skipSpaces(input, 1) + start := index + for index < len(input) && isIdentifierPart(input[index]) { + index++ + } + if start == index { + return nil, false + } + if !strings.EqualFold(input[start:index], directive) { + return nil, false + } + index = skipSpaces(input, index) + if index >= len(input) || input[index] != '(' { + return nil, false + } + body, end, ok := readCallBody(input, index) + if !ok { + return nil, false + } + index = skipSpaces(input, end+1) + if index != len(input) { + return nil, false + } + return splitCallArgs(body), true +} diff --git a/repository/shape/dql/render/dql/renderer.go b/repository/shape/dql/render/dql/renderer.go new file mode 100644 index 000000000..9864b4c04 --- /dev/null +++ b/repository/shape/dql/render/dql/renderer.go @@ -0,0 +1,166 @@ +package dql + +import ( + "fmt" + "strings" + + "github.com/viant/datly/repository/shape/dql/ir" +) + +// SourceResolver resolves template SourceURL content when Source is not embedded in IR. +type SourceResolver func(sourceURL string) (string, error) + +type options struct { + rootView string + resolve SourceResolver +} + +// Option configures DQL rendering. 
+type Option func(*options) + +// WithRootView forces renderer root view selection. +func WithRootView(name string) Option { + return func(o *options) { + o.rootView = strings.TrimSpace(name) + } +} + +// WithSourceResolver configures SourceURL content resolution. +func WithSourceResolver(resolver SourceResolver) Option { + return func(o *options) { + o.resolve = resolver + } +} + +// Encode renders IR document back to DQL/SQL source for the root route view. +func Encode(doc *ir.Document, opts ...Option) ([]byte, error) { + if doc == nil || doc.Root == nil { + return nil, fmt.Errorf("dql render dql: nil IR document") + } + cfg := &options{} + for _, opt := range opts { + if opt != nil { + opt(cfg) + } + } + views := indexViews(doc.Root) + if len(views) == 0 { + return nil, fmt.Errorf("dql render dql: no resource views in IR") + } + rootView := cfg.rootView + if rootView == "" { + rootView = detectRootView(doc.Root) + } + if rootView == "" { + return nil, fmt.Errorf("dql render dql: unable to detect root route view") + } + view := views[rootView] + if view == nil { + return nil, fmt.Errorf("dql render dql: root view %q not found in resources", rootView) + } + sql, err := renderViewSQL(view, cfg) + if err != nil { + return nil, err + } + return []byte(strings.TrimSpace(sql) + "\n"), nil +} + +func renderViewSQL(view map[string]any, cfg *options) (string, error) { + name := stringValue(view["Name"]) + template := mapValue(view["Template"]) + if template == nil { + return "", fmt.Errorf("dql render dql: view %q has no template", name) + } + if source := strings.TrimSpace(stringValue(template["Source"])); source != "" { + return source, nil + } + sourceURL := strings.TrimSpace(stringValue(template["SourceURL"])) + if sourceURL == "" { + return "", fmt.Errorf("dql render dql: view %q has neither template source nor sourceURL", name) + } + source, err := resolveSourceURL(cfg, view, sourceURL) + if err != nil { + return "", fmt.Errorf("dql render dql: view %q resolve %q 
failed: %w", name, sourceURL, err) + } + source = strings.TrimSpace(source) + if source == "" { + return "", fmt.Errorf("dql render dql: resolved source was empty for %q", sourceURL) + } + return source, nil +} + +func resolveSourceURL(cfg *options, view map[string]any, sourceURL string) (string, error) { + _ = view + if cfg.resolve != nil { + return cfg.resolve(sourceURL) + } + return "", fmt.Errorf("requires SourceURL resolver for %q", sourceURL) +} + +func detectRootView(root map[string]any) string { + for _, routeItem := range sliceValue(root["Routes"]) { + route := mapValue(routeItem) + if route == nil { + continue + } + view := mapValue(route["View"]) + if view == nil { + continue + } + if ref := strings.TrimSpace(stringValue(view["Ref"])); ref != "" { + return ref + } + } + return "" +} + +func indexViews(root map[string]any) map[string]map[string]any { + result := map[string]map[string]any{} + resource := mapValue(root["Resource"]) + if resource == nil { + return result + } + for _, item := range sliceValue(resource["Views"]) { + view := mapValue(item) + if view == nil { + continue + } + name := strings.TrimSpace(stringValue(view["Name"])) + if name == "" { + continue + } + result[name] = view + } + return result +} + +func mapValue(raw any) map[string]any { + if v, ok := raw.(map[string]any); ok { + return v + } + if v, ok := raw.(map[any]any); ok { + out := map[string]any{} + for key, item := range v { + out[fmt.Sprint(key)] = item + } + return out + } + return nil +} + +func sliceValue(raw any) []any { + if items, ok := raw.([]any); ok { + return items + } + return nil +} + +func stringValue(raw any) string { + if raw == nil { + return "" + } + if value, ok := raw.(string); ok { + return value + } + return fmt.Sprint(raw) +} diff --git a/repository/shape/dql/render/dql/renderer_test.go b/repository/shape/dql/render/dql/renderer_test.go new file mode 100644 index 000000000..0989854ee --- /dev/null +++ b/repository/shape/dql/render/dql/renderer_test.go @@ 
-0,0 +1,117 @@ +package dql + +import ( + "errors" + "testing" + + "github.com/viant/datly/repository/shape/dql/ir" +) + +func TestEncode_WithEmbeddedSource(t *testing.T) { + doc := &ir.Document{Root: map[string]any{ + "Routes": []any{ + map[string]any{ + "View": map[string]any{"Ref": "root"}, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "root", + "Template": map[string]any{ + "Source": "SELECT * FROM USERS u", + }, + }, + }, + }, + }} + data, err := Encode(doc) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + if got, want := string(data), "SELECT * FROM USERS u\n"; got != want { + t.Fatalf("unexpected dql, got %q want %q", got, want) + } +} + +func TestEncode_WithSourceURLResolver(t *testing.T) { + doc := &ir.Document{Root: map[string]any{ + "Routes": []any{ + map[string]any{ + "View": map[string]any{"Ref": "root"}, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "root", + "Template": map[string]any{ + "SourceURL": "queries/root.sql", + }, + }, + }, + }, + }} + data, err := Encode(doc, WithSourceResolver(func(sourceURL string) (string, error) { + if sourceURL != "queries/root.sql" { + t.Fatalf("unexpected sourceURL: %s", sourceURL) + } + return "SELECT 1", nil + })) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + if got, want := string(data), "SELECT 1\n"; got != want { + t.Fatalf("unexpected dql, got %q want %q", got, want) + } +} + +func TestEncode_SourceURLWithoutResolverFails(t *testing.T) { + doc := &ir.Document{Root: map[string]any{ + "Routes": []any{ + map[string]any{ + "View": map[string]any{"Ref": "root"}, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "root", + "Template": map[string]any{ + "SourceURL": "queries/root.sql", + }, + }, + }, + }, + }} + _, err := Encode(doc) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestEncode_ResolverError(t *testing.T) { + doc := &ir.Document{Root: 
map[string]any{ + "Routes": []any{ + map[string]any{ + "View": map[string]any{"Ref": "root"}, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "root", + "Template": map[string]any{ + "SourceURL": "queries/root.sql", + }, + }, + }, + }, + }} + _, err := Encode(doc, WithSourceResolver(func(sourceURL string) (string, error) { + return "", errors.New("boom") + })) + if err == nil { + t.Fatalf("expected error") + } +} diff --git a/repository/shape/dql/render/yaml/renderer.go b/repository/shape/dql/render/yaml/renderer.go new file mode 100644 index 000000000..83bc39c54 --- /dev/null +++ b/repository/shape/dql/render/yaml/renderer.go @@ -0,0 +1,16 @@ +package yaml + +import ( + "fmt" + + "github.com/viant/datly/repository/shape/dql/ir" + "gopkg.in/yaml.v3" +) + +// Encode renders IR document into YAML bytes. +func Encode(doc *ir.Document) ([]byte, error) { + if doc == nil || doc.Root == nil { + return nil, fmt.Errorf("dql render yaml: nil IR document") + } + return yaml.Marshal(doc.Root) +} diff --git a/repository/shape/dql/sanitize/policy.go b/repository/shape/dql/sanitize/policy.go new file mode 100644 index 000000000..4b6d6b5a4 --- /dev/null +++ b/repository/shape/dql/sanitize/policy.go @@ -0,0 +1,32 @@ +package sanitize + +import "strings" + +type rewritePolicy struct { + declared map[string]bool + consts map[string]bool +} + +func newRewritePolicy(declared, consts map[string]bool) *rewritePolicy { + return &rewritePolicy{ + declared: declared, + consts: consts, + } +} + +func (p *rewritePolicy) rewrite(raw string) string { + holder := holderName(raw) + if holder == "" { + return raw + } + if strings.HasPrefix(raw, "$Unsafe.") || strings.HasPrefix(raw, "${Unsafe.") || strings.HasPrefix(raw, "$Has.") || strings.HasPrefix(raw, "${Has.") { + return raw + } + if p.consts != nil && p.consts[holder] { + return addUnsafePrefix(raw) + } + if p.declared != nil && p.declared[holder] { + return asPlaceholder(raw) + } + return 
asPlaceholder(addUnsafePrefix(raw)) +} diff --git a/repository/shape/dql/sanitize/policy_test.go b/repository/shape/dql/sanitize/policy_test.go new file mode 100644 index 000000000..c3ceb5f23 --- /dev/null +++ b/repository/shape/dql/sanitize/policy_test.go @@ -0,0 +1,49 @@ +package sanitize + +import "testing" + +func TestRewritePolicy_Rewrite(t *testing.T) { + testCases := []struct { + name string + raw string + declared map[string]bool + consts map[string]bool + expect string + }{ + { + name: "plain selector becomes placeholder + unsafe", + raw: "$ID", + expect: "$criteria.AppendBinding($Unsafe.ID)", + }, + { + name: "unsafe selector preserved", + raw: "$Unsafe.ID", + expect: "$Unsafe.ID", + }, + { + name: "declared selector remains local placeholder", + raw: "$x", + declared: map[string]bool{"x": true}, + expect: "$criteria.AppendBinding($x)", + }, + { + name: "const selector keeps raw unsafe path", + raw: "$ConstID", + consts: map[string]bool{"ConstID": true}, + expect: "$Unsafe.ConstID", + }, + { + name: "function call is untouched", + raw: "$Foo.Bar()", + expect: "$Foo.Bar()", + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + policy := newRewritePolicy(testCase.declared, testCase.consts) + if actual := policy.rewrite(testCase.raw); actual != testCase.expect { + t.Fatalf("unexpected rewrite: %s", actual) + } + }) + } +} diff --git a/repository/shape/dql/sanitize/sanitizer.go b/repository/shape/dql/sanitize/sanitizer.go new file mode 100644 index 000000000..ec414620d --- /dev/null +++ b/repository/shape/dql/sanitize/sanitizer.go @@ -0,0 +1,295 @@ +package sanitize + +import ( + "fmt" + "strings" + + "github.com/viant/velty" + "github.com/viant/velty/ast" + aexpr "github.com/viant/velty/ast/expr" +) + +type Options struct { + Declared map[string]bool + Consts map[string]bool +} + +type RewriteResult struct { + SQL string + Patches []velty.Patch + TrimPrefix int +} + +func Declared(input string) map[string]bool { + ret 
:= map[string]bool{} + listener := &declaredListener{declared: ret} + _, _, _ = velty.New(velty.Listener(listener)).Compile([]byte(input)) + for _, name := range scanSetDeclaredHolders(input) { + if name != "" { + ret[name] = true + } + } + return ret +} + +func scanSetDeclaredHolders(input string) []string { + result := make([]string, 0) + lower := strings.ToLower(input) + for i := 0; i < len(input); i++ { + if input[i] != '#' { + continue + } + if !strings.HasPrefix(lower[i:], "#set") { + continue + } + j := i + len("#set") + for j < len(input) && (input[j] == ' ' || input[j] == '\t' || input[j] == '\r' || input[j] == '\n') { + j++ + } + if j >= len(input) || input[j] != '(' { + continue + } + body, end, ok := readSetDirectiveBody(input, j) + if !ok { + continue + } + if name, ok := parseSetDeclaredHolder(body); ok { + result = append(result, name) + } + i = end + } + return result +} + +func parseSetDeclaredHolder(body string) (string, bool) { + text := strings.TrimSpace(body) + if text == "" { + return "", false + } + if !strings.HasPrefix(text, "$_") { + return "", false + } + text = strings.TrimSpace(text[len("$_"):]) + if !strings.HasPrefix(text, "=") { + return "", false + } + text = strings.TrimSpace(text[1:]) + if !strings.HasPrefix(text, "$") || len(text) < 2 { + return "", false + } + name := text[1:] + if !isSanitizeIdentifierStart(name[0]) { + return "", false + } + end := 1 + for end < len(name) && isSanitizeIdentifierPart(name[end]) { + end++ + } + return strings.TrimSpace(name[:end]), true +} + +func readSetDirectiveBody(input string, openParen int) (string, int, bool) { + depth := 0 + quote := byte(0) + for i := openParen; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + depth-- + if depth == 0 { + return 
input[openParen+1 : i], i, true + } + } + } + return "", -1, false +} + +func isSanitizeIdentifierStart(ch byte) bool { + return ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') +} + +func isSanitizeIdentifierPart(ch byte) bool { + return isSanitizeIdentifierStart(ch) || (ch >= '0' && ch <= '9') +} + +func SQL(input string, opts Options) string { + return Rewrite(input, opts).SQL +} + +func Rewrite(input string, opts Options) RewriteResult { + if strings.TrimSpace(input) == "" { + return RewriteResult{SQL: strings.TrimSpace(input)} + } + adjuster := &bindingAdjuster{ + source: []byte(input), + declared: opts.Declared, + consts: opts.Consts, + policy: newRewritePolicy(opts.Declared, opts.Consts), + } + out, err := velty.TransformTemplate([]byte(input), adjuster) + if err != nil { + return RewriteResult{SQL: strings.TrimSpace(input)} + } + trimPrefix := leadingTrimWidth(out) + return RewriteResult{ + SQL: strings.TrimSpace(string(out)), + Patches: append([]velty.Patch{}, adjuster.patches...), + TrimPrefix: trimPrefix, + } +} + +type bindingAdjuster struct { + source []byte + declared map[string]bool + consts map[string]bool + policy *rewritePolicy + patches []velty.Patch +} + +func (b *bindingAdjuster) Adjust(node ast.Node, ctx *velty.ParserContext) (velty.Action, error) { + sel, ok := node.(*aexpr.Select) + if !ok { + return velty.Keep(), nil + } + if ctx.CurrentExprContext().Kind == velty.CtxSetLHS { + return velty.Keep(), nil + } + span, ok := ctx.GetSpan(sel) + if !ok { + return velty.Keep(), nil + } + if b.inSetDirective(span.Start) { + return velty.Keep(), nil + } + raw := string(b.source[span.Start : span.End+1]) + replacement := b.rewrite(raw) + if replacement == raw { + return velty.Keep(), nil + } + b.patches = append(b.patches, velty.Patch{ + Span: span, + Replacement: []byte(replacement), + }) + return velty.PatchSpan(span, []byte(replacement)), nil +} + +func (b *bindingAdjuster) inSetDirective(pos int) bool { + if pos <= 0 || pos > 
len(b.source) { + return false + } + prefix := string(b.source[:pos]) + setPos := strings.LastIndex(prefix, "#set(") + if setPos == -1 { + return false + } + if nl := strings.LastIndex(prefix, "\n"); nl > setPos { + return false + } + segment := prefix[setPos:pos] + return strings.Count(segment, "(") > strings.Count(segment, ")") +} + +func (b *bindingAdjuster) rewrite(raw string) string { + if b.policy == nil { + b.policy = newRewritePolicy(b.declared, b.consts) + } + return b.policy.rewrite(raw) +} + +func holderName(raw string) string { + name := strings.TrimSpace(raw) + if name == "" { + return "" + } + if strings.HasPrefix(name, "${") && strings.HasSuffix(name, "}") { + name = "$" + name[2:len(name)-1] + } + if !strings.HasPrefix(name, "$") { + return "" + } + name = strings.TrimPrefix(name, "$") + if idx := strings.Index(name, "("); idx != -1 { + return "" + } + if idx := strings.Index(name, "."); idx != -1 { + head := name[:idx] + if head == "Unsafe" || head == "Has" { + name = name[idx+1:] + if j := strings.Index(name, "."); j != -1 { + return name[:j] + } + return name + } + return head + } + return name +} + +func addUnsafePrefix(raw string) string { + if strings.HasPrefix(raw, "${") { + return strings.Replace(raw, "${", "${Unsafe.", 1) + } + return strings.Replace(raw, "$", "$Unsafe.", 1) +} + +func asPlaceholder(raw string) string { + if strings.HasPrefix(raw, "${") && strings.HasSuffix(raw, "}") { + inner := "$" + raw[2:len(raw)-1] + return fmt.Sprintf("${criteria.AppendBinding(%s)}", inner) + } + return fmt.Sprintf("$criteria.AppendBinding(%s)", raw) +} + +func leadingTrimWidth(data []byte) int { + i := 0 + for i < len(data) { + switch data[i] { + case ' ', '\t', '\r', '\n': + i++ + default: + return i + } + } + return i +} + +type declaredListener struct { + declared map[string]bool +} + +func (d *declaredListener) OnEvent(e velty.Event) { + if e.Type != velty.EventEnterNode { + return + } + if e.ExprContext.Kind != velty.CtxSetLHS { + return + } + 
sel, ok := e.Node.(*aexpr.Select) + if !ok { + return + } + name := holderName(sel.FullName) + if name == "" { + name = holderName("$" + sel.ID) + } + if name != "" { + d.declared[name] = true + } +} diff --git a/repository/shape/dql/sanitize/sanitizer_test.go b/repository/shape/dql/sanitize/sanitizer_test.go new file mode 100644 index 000000000..326bd3b1d --- /dev/null +++ b/repository/shape/dql/sanitize/sanitizer_test.go @@ -0,0 +1,245 @@ +package sanitize + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/internal/inference" + legacy "github.com/viant/datly/internal/translator/parser" + vstate "github.com/viant/datly/view/state" + "github.com/viant/velty" + "github.com/viant/velty/ast/expr" +) + +func TestSQL_ParityWithLegacySanitizer(t *testing.T) { + testCases := []struct { + name string + sql string + state inference.State + }{ + { + name: "unsafe binding from plain selector", + sql: "SELECT * FROM t WHERE id = $Id", + }, + { + name: "bracket selector placeholder", + sql: "SELECT * FROM t WHERE id = ${Id}", + }, + { + name: "declared variable in append context", + sql: "#set($x = 1)\nSELECT * FROM t WHERE id = $x", + }, + { + name: "const selector keeps raw unsafe prefix", + sql: "SELECT * FROM t WHERE id = $ConstId", + state: inference.State{ + &inference.Parameter{Parameter: vstate.Parameter{Name: "ConstId", In: vstate.NewConstLocation("ConstId")}}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + state := testCase.state + tpl, err := legacy.NewTemplate(testCase.sql, &state) + require.NoError(t, err) + expected := tpl.Sanitize() + + actual := SQL(testCase.sql, Options{ + Declared: tpl.Declared, + Consts: constNames(state), + }) + assert.Equal(t, expected, actual) + }) + } +} + +func TestSQL_ParityWithLegacySanitizer_RuntimeExpansion(t *testing.T) { + testCases := []struct { + name string + sql string + state 
inference.State + }{ + { + name: "plain selector binding", + sql: "SELECT * FROM t WHERE id = $Id", + }, + { + name: "bracket selector binding", + sql: "SELECT * FROM t WHERE id = ${Id}", + }, + { + name: "declared variable binding", + sql: "#set($x = 7)\nSELECT * FROM t WHERE id = $x", + }, + { + name: "const raw unsafe", + sql: "SELECT * FROM t WHERE id = $ConstId", + state: inference.State{ + &inference.Parameter{Parameter: vstate.Parameter{Name: "ConstId", In: vstate.NewConstLocation("ConstId")}}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + state := testCase.state + tpl, err := legacy.NewTemplate(testCase.sql, &state) + require.NoError(t, err) + legacySQL := tpl.Sanitize() + + shapeSQL := SQL(testCase.sql, Options{ + Declared: tpl.Declared, + Consts: constNames(state), + }) + require.Equal(t, legacySQL, shapeSQL) + + assert.Equal(t, renderVeltySQL(t, legacySQL), renderVeltySQL(t, shapeSQL)) + }) + } +} + +func TestHolderName(t *testing.T) { + assert.Equal(t, "Foo", holderName("$Foo")) + assert.Equal(t, "Foo", holderName("${Foo}")) + assert.Equal(t, "Foo", holderName("$Foo.Bar")) + assert.Equal(t, "Foo", holderName("$Unsafe.Foo")) + assert.Equal(t, "Foo", holderName("$Has.Foo")) + assert.Equal(t, "Foo", holderName("$Unsafe.Foo.Bar")) + assert.Equal(t, "", holderName("$Foo.Bar()")) + assert.Equal(t, "", holderName("Foo")) + assert.Equal(t, "", holderName("")) +} + +func TestAddUnsafePrefixAndPlaceholder(t *testing.T) { + assert.Equal(t, "$Unsafe.Foo", addUnsafePrefix("$Foo")) + assert.Equal(t, "${Unsafe.Foo}", addUnsafePrefix("${Foo}")) + assert.Equal(t, "$criteria.AppendBinding($Foo)", asPlaceholder("$Foo")) + assert.Equal(t, "${criteria.AppendBinding($Foo)}", asPlaceholder("${Foo}")) +} + +func TestSQL_EdgeBranches(t *testing.T) { + assert.Equal(t, "", SQL(" ", Options{})) + assert.Equal(t, "#if(", SQL("#if(", Options{})) + assert.Equal(t, "#if(true)", SQL("#if(true)", Options{})) + assert.Equal(t, 
"SELECT $Unsafe.Id", SQL("SELECT $Unsafe.Id", Options{})) + assert.Equal(t, "SELECT $Has.Id", SQL("SELECT $Has.Id", Options{})) + assert.Equal(t, "SELECT $Foo.Bar()", SQL("SELECT $Foo.Bar()", Options{})) + assert.Equal(t, "#set($x = $y)\nSELECT $criteria.AppendBinding($Unsafe.y)", SQL("#set($x = $y)\nSELECT $y", Options{})) +} + +func TestSQL_RewritePreservesLineCount(t *testing.T) { + input := "#set($x = 1)\nSELECT *\nFROM t\nWHERE id = $Id\nAND name = ${Name}\n" + out := SQL(input, Options{}) + assert.Equal(t, strings.Count(strings.TrimSpace(input), "\n"), strings.Count(out, "\n")) +} + +func TestInSetDirective(t *testing.T) { + adj := &bindingAdjuster{source: []byte("#set($x = $y)\nSELECT $z")} + assert.True(t, adj.inSetDirective(7)) + assert.False(t, adj.inSetDirective(len(adj.source))) + assert.False(t, adj.inSetDirective(-1)) +} + +func TestDeclared(t *testing.T) { + declared := Declared("#set($x = 1)\n#set($y = $x)\nSELECT $x, $z") + assert.True(t, declared["x"]) + assert.True(t, declared["y"]) + assert.False(t, declared["z"]) +} + +func TestDeclared_ParameterDeclarationStyle(t *testing.T) { + declared := Declared("#set($_ = $Jwt(header/Authorization))\nSELECT $Jwt.UserID") + assert.True(t, declared["Jwt"]) +} + +func TestDeclaredListener_OnEventBranches(t *testing.T) { + declared := map[string]bool{} + l := &declaredListener{declared: declared} + + l.OnEvent(velty.Event{Type: velty.EventExitNode}) + l.OnEvent(velty.Event{Type: velty.EventEnterNode, ExprContext: velty.ExprContext{Kind: velty.CtxIfCond}}) + l.OnEvent(velty.Event{Type: velty.EventEnterNode, ExprContext: velty.ExprContext{Kind: velty.CtxSetLHS}, Node: &expr.Literal{Value: "x"}}) + l.OnEvent(velty.Event{Type: velty.EventEnterNode, ExprContext: velty.ExprContext{Kind: velty.CtxSetLHS}, Node: &expr.Select{ID: "x"}}) + + assert.True(t, declared["x"]) +} + +func TestAdjust_Branches(t *testing.T) { + adj := &bindingAdjuster{source: []byte("SELECT $Unsafe.Id")} + + // non selector node + action, err 
:= adj.Adjust(&expr.Literal{Value: "x"}, &velty.ParserContext{}) + require.NoError(t, err) + assert.Equal(t, velty.ActionKeep, action.Kind) + + // selector without span + sel := &expr.Select{FullName: "$Unsafe.Id", ID: "Unsafe"} + action, err = adj.Adjust(sel, &velty.ParserContext{}) + require.NoError(t, err) + assert.Equal(t, velty.ActionKeep, action.Kind) + + // selector in set lhs context + ctx := &velty.ParserContext{} + ctx.InitSource("", adj.source) + ctx.SetSpan(sel, velty.Span{Start: 7, End: 16}) + ctx.PushExprContext(velty.ExprContext{Kind: velty.CtxSetLHS, ArgIdx: -1}) + action, err = adj.Adjust(sel, ctx) + require.NoError(t, err) + assert.Equal(t, velty.ActionKeep, action.Kind) + + // selector replacement equals raw + ctx.PopExprContext() + action, err = adj.Adjust(sel, ctx) + require.NoError(t, err) + assert.Equal(t, velty.ActionKeep, action.Kind) +} + +func constNames(state inference.State) map[string]bool { + ret := map[string]bool{} + for _, param := range state { + if param == nil || param.In == nil { + continue + } + if param.In.Kind == vstate.KindConst { + ret[param.Name] = true + } + } + return ret +} + +type criteriaMock struct{} + +func (c criteriaMock) AppendBinding(value interface{}) string { + return fmt.Sprintf("{%v}", value) +} + +type unsafeMock struct { + Id int + Name string + ConstId int +} + +func renderVeltySQL(t *testing.T, template string) string { + t.Helper() + planner := velty.New() + require.NoError(t, planner.DefineVariable("criteria", criteriaMock{})) + require.NoError(t, planner.DefineVariable("Unsafe", unsafeMock{})) + require.NoError(t, planner.DefineVariable("Id", 0)) + require.NoError(t, planner.DefineVariable("Name", "")) + require.NoError(t, planner.DefineVariable("ConstId", 0)) + + exec, newState, err := planner.Compile([]byte(template)) + require.NoError(t, err) + state := newState() + require.NoError(t, state.SetValue("criteria", criteriaMock{})) + require.NoError(t, state.SetValue("Unsafe", unsafeMock{Id: 10, Name: 
"ann", ConstId: 77})) + require.NoError(t, state.SetValue("Id", 10)) + require.NoError(t, state.SetValue("Name", "ann")) + require.NoError(t, state.SetValue("ConstId", 77)) + require.NoError(t, exec.Exec(state)) + return state.Buffer.String() +} diff --git a/repository/shape/dql/scan/scanner.go b/repository/shape/dql/scan/scanner.go new file mode 100644 index 000000000..b7ecb2d69 --- /dev/null +++ b/repository/shape/dql/scan/scanner.go @@ -0,0 +1,431 @@ +package scan + +import ( + "context" + "fmt" + "path/filepath" + "reflect" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/viant/afs" + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/internal/translator" + "github.com/viant/datly/repository/shape/dql/decl" + "github.com/viant/datly/repository/shape/dql/ir" + "github.com/viant/datly/repository/shape/dql/parse" + dqlplan "github.com/viant/datly/repository/shape/dql/plan" + "github.com/viant/datly/repository/shape/dql/sanitize" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/repository/shape/typectx/source" + _ "github.com/viant/sqlx/metadata/product/mysql" + "github.com/viant/x" +) + +// Request defines input for DQL scan. +type Request struct { + DQLURL string + ConfigURL string + Repository string + ModulePrefix string + APIPrefix string + Connectors []string + AllowedProvenanceKinds []string + AllowedSourceRoots []string + UseGoModuleResolve *bool + UseGOPATHFallback *bool + StrictProvenance *bool +} + +// Result holds scanner output. +type Result struct { + RuleName string + Shape *dqlshape.Document + IR *ir.Document +} + +// Scanner translates DQL to Datly route YAML in-memory. 
+type Scanner struct { + fs afs.Service +} + +func New() *Scanner { + return &Scanner{fs: afs.New()} +} + +func (s *Scanner) Scan(ctx context.Context, req *Request) (result *Result, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("dql scan panic: %v", r) + result = nil + } + }() + if req == nil || req.DQLURL == "" { + return nil, fmt.Errorf("dql scan: DQLURL was empty") + } + sourceURL := req.DQLURL + project := inferProject(req.DQLURL) + translate := &options.Translate{} + translate.Rule.Project = project + translate.Rule.Source = []string{sourceURL} + translate.Rule.ModulePrefix = req.ModulePrefix + translate.Repository.RepositoryURL = req.Repository + translate.Repository.APIPrefix = req.APIPrefix + if len(req.Connectors) > 0 { + translate.Repository.Connectors = append(translate.Repository.Connectors, req.Connectors...) + } + if req.ConfigURL != "" { + translate.Repository.Configs.Append(req.ConfigURL) + } + var initErr error + if initErr = translate.Init(ctx); initErr != nil { + return nil, initErr + } + if req.ConfigURL == "" { + // Force in-memory translator config to avoid stale absolute paths from discovered config.json. 
+ translate.Repository.Configs = nil + } + if translate.Rule.ModulePrefix == "" { + translate.Rule.ModulePrefix = "platform" + } + + svc := translator.New(translator.NewConfig(&translate.Repository), s.fs) + if initErr := svc.Init(ctx); initErr != nil { + return nil, initErr + } + if initErr := svc.InitSignature(ctx, &translate.Rule); initErr != nil { + return nil, initErr + } + dsql, loadErr := translate.Rule.LoadSource(ctx, s.fs, translate.Rule.SourceURL()) + if loadErr != nil { + return nil, loadErr + } + translate.Rule.NormalizeComponent(&dsql) + dsql = sanitize.SQL(dsql, sanitize.Options{Declared: sanitize.Declared(dsql)}) + top := &options.Options{Translate: translate} + if initErr = svc.Translate(ctx, &translate.Rule, dsql, top); initErr != nil { + return nil, initErr + } + ruleName := svc.Repository.RuleName(&translate.Rule) + targetSuffix := "/" + ruleName + ".yaml" + for _, item := range svc.Repository.Files { + if !strings.HasSuffix(item.URL, targetSuffix) { + continue + } + if strings.Contains(item.URL, "/.meta/") { + continue + } + return s.result(ruleName, []byte(item.Content), dsql, req) + } + for _, item := range svc.Repository.Files { + if strings.HasSuffix(item.URL, targetSuffix) { + return s.result(ruleName, []byte(item.Content), dsql, req) + } + } + return nil, fmt.Errorf("dql scan: generated YAML not found for %s", ruleName) +} + +func (s *Scanner) result(ruleName string, routeYAML []byte, dql string, req *Request) (*Result, error) { + if err := dqlplan.ValidateRelations(routeYAML); err != nil { + return nil, fmt.Errorf("dql scan relation validation failed (%s): %w", ruleName, err) + } + fromYAML, err := ir.FromYAML(routeYAML) + if err != nil { + return nil, err + } + shapeDoc, err := dqlshape.FromIR(fromYAML) + if err != nil { + return nil, err + } + if parsed, parseErr := parse.New().Parse(dql); parseErr == nil && parsed != nil && parsed.TypeContext != nil { + shapeDoc.TypeContext = parsed.TypeContext + } + if declarations, declErr := 
decl.Parse(dql); declErr == nil && len(declarations) > 0 { + if resolutions, resolveErr := resolveTypeProvenance(declarations, shapeDoc.TypeContext, fromYAML, req); resolveErr != nil { + return nil, resolveErr + } else { + shapeDoc.TypeResolutions = resolutions + } + } + rebuiltIR, err := dqlshape.ToIR(shapeDoc) + if err != nil { + return nil, err + } + return &Result{RuleName: ruleName, Shape: shapeDoc, IR: rebuiltIR}, nil +} + +func resolveTypeProvenance(declarations []*decl.Declaration, ctx *typectx.Context, doc *ir.Document, req *Request) ([]typectx.Resolution, error) { + if len(declarations) == 0 { + return nil, nil + } + registry, provenance := registryFromIR(doc) + resolver := typectx.NewResolverWithProvenance(registry, ctx, provenance) + policy := newProvenancePolicy(req) + srcResolver, srcErr := newSourceResolver(policy, req) + if srcErr != nil && policy.Strict { + return nil, srcErr + } + var result []typectx.Resolution + for _, declaration := range declarations { + if declaration == nil || declaration.Kind != decl.KindCast { + continue + } + expression := strings.TrimSpace(declaration.DataType) + if expression == "" { + continue + } + resolution, err := resolver.ResolveWithProvenance(expression) + if err != nil { + return nil, fmt.Errorf("dql scan cast resolution failed for %q: %w", expression, err) + } + if resolution == nil { + continue + } + resolution.Target = declaration.Target + enrichResolutionWithAST(resolution, srcResolver) + if issue := validateResolutionPolicy(*resolution, policy); issue != "" { + if policy.Strict { + return nil, fmt.Errorf("dql scan provenance policy failed: %s", issue) + } + resolution.Provenance.Kind = "policy_warn:" + issue + } + result = append(result, *resolution) + } + return result, nil +} + +func registryFromIR(doc *ir.Document) (*x.Registry, map[string]typectx.Provenance) { + registry := x.NewRegistry() + provenance := map[string]typectx.Provenance{} + registerBuiltin := func(rType reflect.Type, kind string) { + 
aType := x.NewType(rType) + registry.Register(aType) + provenance[aType.Key()] = typectx.Provenance{ + Package: packageOfKey(aType.Key()), + Kind: kind, + } + } + registerBuiltin(reflect.TypeOf(time.Time{}), "builtin") + registerBuiltin(reflect.TypeOf(""), "builtin") + registerBuiltin(reflect.TypeOf(0), "builtin") + registerBuiltin(reflect.TypeOf(int64(0)), "builtin") + registerBuiltin(reflect.TypeOf(float64(0)), "builtin") + registerBuiltin(reflect.TypeOf(true), "builtin") + + if doc == nil || doc.Root == nil { + return registry, provenance + } + resource := asMap(doc.Root["Resource"]) + if resource == nil { + return registry, provenance + } + for _, item := range asSlice(resource["Types"]) { + typeMap := asMap(item) + if typeMap == nil { + continue + } + name := strings.TrimSpace(asString(typeMap["Name"])) + if name == "" { + continue + } + pkg := strings.TrimSpace(asString(typeMap["Package"])) + aType := &x.Type{Name: name, PkgPath: pkg} + registry.Register(aType) + key := aType.Key() + provenance[key] = typectx.Provenance{ + Package: pkg, + File: firstNonEmpty(asString(typeMap["SourceURL"]), asString(typeMap["ModulePath"])), + Kind: "resource_type", + } + } + return registry, provenance +} + +func asMap(raw any) map[string]any { + if value, ok := raw.(map[string]any); ok { + return value + } + if value, ok := raw.(map[any]any); ok { + result := make(map[string]any, len(value)) + for k, v := range value { + result[fmt.Sprint(k)] = v + } + return result + } + return nil +} + +func asSlice(raw any) []any { + if value, ok := raw.([]any); ok { + return value + } + return nil +} + +func asString(raw any) string { + if raw == nil { + return "" + } + if value, ok := raw.(string); ok { + return value + } + return fmt.Sprint(raw) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} + +func packageOfKey(key string) string { + index := 
strings.LastIndex(key, ".") + if index == -1 { + return "" + } + return key[:index] +} + +type provenancePolicy struct { + AllowedKinds map[string]bool + Roots []string + Strict bool +} + +func newProvenancePolicy(req *Request) provenancePolicy { + allowedKinds := map[string]bool{ + "builtin": true, + "resource_type": true, + "ast_type": true, + } + if req != nil && len(req.AllowedProvenanceKinds) > 0 { + allowedKinds = map[string]bool{} + for _, item := range req.AllowedProvenanceKinds { + item = strings.TrimSpace(strings.ToLower(item)) + if item != "" { + allowedKinds[item] = true + } + } + } + repo := "" + if req != nil { + repo = req.Repository + } + roots := source.NormalizeRoots(repo, requestRoots(req)) + return provenancePolicy{ + AllowedKinds: allowedKinds, + Roots: roots, + Strict: requestStrict(req), + } +} + +func requestRoots(req *Request) []string { + if req == nil { + return nil + } + return req.AllowedSourceRoots +} + +func requestStrict(req *Request) bool { + if req == nil || req.StrictProvenance == nil { + return true + } + return *req.StrictProvenance +} + +func requestUseModule(req *Request) bool { + if req == nil || req.UseGoModuleResolve == nil { + return true + } + return *req.UseGoModuleResolve +} + +func requestUseGOPATH(req *Request) bool { + if req == nil || req.UseGOPATHFallback == nil { + return true + } + return *req.UseGOPATHFallback +} + +func newSourceResolver(policy provenancePolicy, req *Request) (*source.Resolver, error) { + if req == nil || strings.TrimSpace(req.Repository) == "" { + return nil, nil + } + return source.New(source.Config{ + ProjectDir: req.Repository, + AllowedSourceRoots: policy.Roots, + UseGoModuleResolve: requestUseModule(req), + UseGOPATHFallback: requestUseGOPATH(req), + }) +} + +func enrichResolutionWithAST(resolution *typectx.Resolution, srcResolver *source.Resolver) { + if resolution == nil || srcResolver == nil { + return + } + if strings.TrimSpace(resolution.Provenance.File) != "" { + return + } + pkg := 
strings.TrimSpace(resolution.Provenance.Package) + typeName := typeNameFromKey(resolution.ResolvedKey) + if pkg == "" || typeName == "" { + return + } + filePath, err := srcResolver.ResolveTypeFile(pkg, typeName) + if err != nil { + return + } + resolution.Provenance.File = filePath + if resolution.Provenance.Kind == "" || resolution.Provenance.Kind == "registry" { + resolution.Provenance.Kind = "ast_type" + } +} + +func typeNameFromKey(key string) string { + index := strings.LastIndex(key, ".") + if index == -1 || index+1 >= len(key) { + return "" + } + return key[index+1:] +} + +func validateResolutionPolicy(resolution typectx.Resolution, policy provenancePolicy) string { + kind := strings.TrimSpace(strings.ToLower(resolution.Provenance.Kind)) + if kind == "" { + kind = "registry" + } + if !policy.AllowedKinds[kind] { + return fmt.Sprintf("expression=%q kind=%q not allowed", resolution.Expression, resolution.Provenance.Kind) + } + filePath := strings.TrimSpace(resolution.Provenance.File) + if filePath == "" { + return "" + } + if len(policy.Roots) == 0 { + return "" + } + ok, err := source.IsWithinAnyRoot(filePath, policy.Roots) + if err != nil { + return fmt.Sprintf("expression=%q source=%q invalid: %v", resolution.Expression, filePath, err) + } + if !ok { + return fmt.Sprintf("expression=%q source=%q outside trusted roots", resolution.Expression, filePath) + } + return "" +} + +func inferProject(dqlURL string) string { + base, _ := url.Split(dqlURL, file.Scheme) + if idx := strings.Index(base, "/dql/"); idx != -1 { + return filepath.Clean(base[:idx]) + } + return filepath.Clean(base) +} diff --git a/repository/shape/dql/scan/scanner_test.go b/repository/shape/dql/scan/scanner_test.go new file mode 100644 index 000000000..50e90999a --- /dev/null +++ b/repository/shape/dql/scan/scanner_test.go @@ -0,0 +1,164 @@ +package scan + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestScanner_Result_ValidatesRelations(t 
*testing.T) { + s := New() + invalidYAML := []byte(` +Resource: + Views: + - Name: Parent + Template: + Source: SELECT p.ID FROM T p + With: + - Name: rel + Holder: Rel + Cardinality: One + On: + - Column: MISSING_COL + Namespace: p + Of: + Ref: Child + On: + - Column: ID + Namespace: c + - Name: Child + Template: + Source: SELECT c.ID FROM T2 c +`) + _, err := s.result("x", invalidYAML, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "dql scan relation validation failed") + require.Contains(t, err.Error(), "column=\"MISSING_COL\"") +} + +func TestScanner_Result_BuildsShapeAndIR(t *testing.T) { + s := New() + validYAML := []byte(` +Routes: + - Name: Sample + URI: /sample + Method: GET + View: + Ref: root +Resource: + Views: + - Name: root + Connector: + Ref: main + Template: + Source: SELECT r.ID FROM ROOT r +`) + result, err := s.result("sample", validYAML, "", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Shape) + require.NotNil(t, result.IR) + require.Equal(t, "root", result.Shape.Routes[0].ViewRef) +} + +func TestScanner_Result_PropagatesTypeContextFromDQL(t *testing.T) { + s := New() + validYAML := []byte(` +Routes: + - Name: Sample + URI: /sample + Method: GET + View: + Ref: root +Resource: + Views: + - Name: root + Connector: + Ref: main + Template: + Source: SELECT r.ID FROM ROOT r +`) + dql := ` +#package('mdp/performance') +#import('perf', 'github.com/acme/mdp/performance') +SELECT r.ID FROM ROOT r` + result, err := s.result("sample", validYAML, dql, nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Shape) + require.NotNil(t, result.Shape.TypeContext) + require.Equal(t, "mdp/performance", result.Shape.TypeContext.DefaultPackage) + require.Len(t, result.Shape.TypeContext.Imports, 1) + require.Equal(t, "perf", result.Shape.TypeContext.Imports[0].Alias) +} + +func TestScanner_Result_ResolvesTypeProvenance(t *testing.T) { + s := New() + validYAML := []byte(` 
+Routes: + - Name: Sample + URI: /sample + Method: GET + View: + Ref: root +Resource: + Types: + - Name: Order + Package: github.com/acme/mdp/performance + SourceURL: /repo/mdp/performance/order.go + Views: + - Name: root + Connector: + Ref: main + Template: + Source: SELECT r.ID FROM ROOT r +`) + dql := ` +#package('github.com/acme/mdp/performance') +SELECT cast(r.ID as 'Order') FROM ROOT r` + result, err := s.result("sample", validYAML, dql, nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Shape) + require.Len(t, result.Shape.TypeResolutions, 1) + resolution := result.Shape.TypeResolutions[0] + require.Equal(t, "Order", resolution.Expression) + require.Equal(t, "github.com/acme/mdp/performance.Order", resolution.ResolvedKey) + require.Contains(t, []string{"default_package", "global_unique"}, resolution.MatchKind) + require.Equal(t, "resource_type", resolution.Provenance.Kind) + require.Equal(t, "/repo/mdp/performance/order.go", resolution.Provenance.File) +} + +func TestScanner_Result_StrictProvenanceBlocksOutsideRoot(t *testing.T) { + s := New() + validYAML := []byte(` +Routes: + - Name: Sample + URI: /sample + Method: GET + View: + Ref: root +Resource: + Types: + - Name: Order + Package: github.com/acme/mdp/performance + SourceURL: /outside/order.go + Views: + - Name: root + Connector: + Ref: main + Template: + Source: SELECT r.ID FROM ROOT r +`) + dql := ` +#package('github.com/acme/mdp/performance') +SELECT cast(r.ID as 'Order') FROM ROOT r` + strict := true + _, err := s.result("sample", validYAML, dql, &Request{ + Repository: filepath.Clean(t.TempDir()), + StrictProvenance: &strict, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "provenance policy failed") +} diff --git a/repository/shape/dql/shape/convert.go b/repository/shape/dql/shape/convert.go new file mode 100644 index 000000000..00a88eff9 --- /dev/null +++ b/repository/shape/dql/shape/convert.go @@ -0,0 +1,235 @@ +package shape + +import ( + "fmt" 
+ + "github.com/viant/datly/repository/shape/dql/ir" + "github.com/viant/datly/repository/shape/typectx" +) + +// FromIR builds typed shape document from IR. +func FromIR(doc *ir.Document) (*Document, error) { + if doc == nil || doc.Root == nil { + return nil, fmt.Errorf("dql shape: nil IR document") + } + root, ok := deepClone(doc.Root).(map[string]any) + if !ok || root == nil { + return nil, fmt.Errorf("dql shape: invalid IR root") + } + ret := &Document{Root: root} + ret.TypeContext = typeContextFromRoot(root) + ret.TypeResolutions = typeResolutionsFromRoot(root) + for _, item := range asSlice(root["Routes"]) { + routeMap := asMap(item) + if routeMap == nil { + continue + } + route := &Route{ + Name: asString(routeMap["Name"]), + URI: asString(routeMap["URI"]), + Method: asString(routeMap["Method"]), + Description: asString(routeMap["Description"]), + } + if view := asMap(routeMap["View"]); view != nil { + route.ViewRef = asString(view["Ref"]) + } + ret.Routes = append(ret.Routes, route) + } + resourceMap := asMap(root["Resource"]) + if resourceMap != nil { + resource := &Resource{} + for _, item := range asSlice(resourceMap["Views"]) { + viewMap := asMap(item) + if viewMap == nil { + continue + } + view := &View{ + Name: asString(viewMap["Name"]), + Table: asString(viewMap["Table"]), + Module: asString(viewMap["Module"]), + } + if connector := asMap(viewMap["Connector"]); connector != nil { + view.ConnectorRef = asString(connector["Ref"]) + } + resource.Views = append(resource.Views, view) + } + ret.Resource = resource + } + return ret, nil +} + +// ToIR converts shape document back to IR. 
+func ToIR(doc *Document) (*ir.Document, error) {
+	if doc == nil || doc.Root == nil {
+		return nil, fmt.Errorf("dql shape: nil document")
+	}
+	// Deep-clone first so the caller's Root map is never mutated when the
+	// type metadata sections are injected below.
+	root, ok := deepClone(doc.Root).(map[string]any)
+	if !ok || root == nil {
+		return nil, fmt.Errorf("dql shape: invalid root")
+	}
+	if doc.TypeContext != nil {
+		root["TypeContext"] = map[string]any{
+			"DefaultPackage": doc.TypeContext.DefaultPackage,
+			"Imports":        importsToAny(doc.TypeContext.Imports),
+		}
+	}
+	if len(doc.TypeResolutions) > 0 {
+		root["TypeResolutions"] = typeResolutionsToAny(doc.TypeResolutions)
+	}
+	return &ir.Document{Root: root}, nil
+}
+
+// typeContextFromRoot extracts the optional TypeContext section from the IR
+// root; it returns nil when the section is absent or effectively empty
+// (no default package and no imports).
+func typeContextFromRoot(root map[string]any) *typectx.Context {
+	raw := asMap(root["TypeContext"])
+	if raw == nil {
+		return nil
+	}
+	ret := &typectx.Context{DefaultPackage: asString(raw["DefaultPackage"])}
+	for _, item := range asSlice(raw["Imports"]) {
+		importMap := asMap(item)
+		if importMap == nil {
+			continue
+		}
+		pkg := asString(importMap["Package"])
+		if pkg == "" {
+			// An import without a package carries no information; skip it.
+			continue
+		}
+		ret.Imports = append(ret.Imports, typectx.Import{
+			Alias:   asString(importMap["Alias"]),
+			Package: pkg,
+		})
+	}
+	if ret.DefaultPackage == "" && len(ret.Imports) == 0 {
+		return nil
+	}
+	return ret
+}
+
+// importsToAny converts typed imports to the generic []any form stored in the
+// IR root; imports with an empty package are dropped and an empty result
+// collapses to nil.
+func importsToAny(imports []typectx.Import) []any {
+	if len(imports) == 0 {
+		return nil
+	}
+	result := make([]any, 0, len(imports))
+	for _, item := range imports {
+		if item.Package == "" {
+			continue
+		}
+		result = append(result, map[string]any{
+			"Alias":   item.Alias,
+			"Package": item.Package,
+		})
+	}
+	return result
+}
+
+// typeResolutionsFromRoot extracts the optional TypeResolutions section from
+// the IR root; entries with neither an expression nor a resolved key are
+// discarded, and nil is returned when the section is absent.
+func typeResolutionsFromRoot(root map[string]any) []typectx.Resolution {
+	items := asSlice(root["TypeResolutions"])
+	if len(items) == 0 {
+		return nil
+	}
+	result := make([]typectx.Resolution, 0, len(items))
+	for _, item := range items {
+		resolutionMap := asMap(item)
+		if resolutionMap == nil {
+			continue
+		}
+		resolution := typectx.Resolution{
+			Expression:  asString(resolutionMap["Expression"]),
+			Target:      asString(resolutionMap["Target"]),
+			ResolvedKey: asString(resolutionMap["ResolvedKey"]),
+			MatchKind:   asString(resolutionMap["MatchKind"]),
+		}
+		if provenanceMap := asMap(resolutionMap["Provenance"]); provenanceMap != nil {
+			resolution.Provenance = typectx.Provenance{
+				Package: asString(provenanceMap["Package"]),
+				File:    asString(provenanceMap["File"]),
+				Kind:    asString(provenanceMap["Kind"]),
+			}
+		}
+		if resolution.Expression == "" && resolution.ResolvedKey == "" {
+			continue
+		}
+		result = append(result, resolution)
+	}
+	return result
+}
+
+// typeResolutionsToAny converts typed resolutions to the generic []any form
+// stored in the IR root; it mirrors typeResolutionsFromRoot, skipping empty
+// entries and collapsing an empty result to nil.
+func typeResolutionsToAny(resolutions []typectx.Resolution) []any {
+	result := make([]any, 0, len(resolutions))
+	for _, item := range resolutions {
+		if item.Expression == "" && item.ResolvedKey == "" {
+			continue
+		}
+		result = append(result, map[string]any{
+			"Expression":  item.Expression,
+			"Target":      item.Target,
+			"ResolvedKey": item.ResolvedKey,
+			"MatchKind":   item.MatchKind,
+			"Provenance": map[string]any{
+				"Package": item.Provenance.Package,
+				"File":    item.Provenance.File,
+				"Kind":    item.Provenance.Kind,
+			},
+		})
+	}
+	if len(result) == 0 {
+		return nil
+	}
+	return result
+}
+
+// deepClone returns a recursive copy of nested maps and slices, normalizing
+// map[any]any keys to strings via fmt.Sprint (YAML decoders may produce such
+// maps); any other value is returned as-is.
+func deepClone(value any) any {
+	switch actual := value.(type) {
+	case map[string]any:
+		out := make(map[string]any, len(actual))
+		for k, v := range actual {
+			out[k] = deepClone(v)
+		}
+		return out
+	case map[any]any:
+		out := make(map[string]any, len(actual))
+		for k, v := range actual {
+			out[fmt.Sprint(k)] = deepClone(v)
+		}
+		return out
+	case []any:
+		out := make([]any, len(actual))
+		for i, item := range actual {
+			out[i] = deepClone(item)
+		}
+		return out
+	default:
+		return actual
+	}
+}
+
+// asMap coerces raw into map[string]any; map[any]any keys are converted via
+// fmt.Sprint (values are NOT deep-converted), and nil is returned for any
+// other type.
+func asMap(raw any) map[string]any {
+	if value, ok := raw.(map[string]any); ok {
+		return value
+	}
+	if value, ok := raw.(map[any]any); ok {
+		out := make(map[string]any, len(value))
+		for k, item := range value {
+			out[fmt.Sprint(k)] = item
+		}
+		return out
+	}
+	return nil
+}
+
+// asSlice returns raw as []any, or nil when it is any other type.
+func asSlice(raw any) []any {
+	if value, ok := raw.([]any); ok {
+		return value
+	}
+	return nil
+}
+
+// asString renders raw as a string: nil becomes "", strings pass through,
+// anything else is formatted with fmt.Sprint.
+func asString(raw any) string {
+	if raw == nil {
+		return ""
+	}
+	if value, ok := raw.(string); ok {
+		return value
+	}
+	return fmt.Sprint(raw)
+}
diff --git a/repository/shape/dql/shape/convert_test.go b/repository/shape/dql/shape/convert_test.go
new file mode 100644
index 000000000..0e41a43fd
--- /dev/null
+++ b/repository/shape/dql/shape/convert_test.go
@@ -0,0 +1,121 @@
+package shape
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/viant/datly/repository/shape/dql/ir"
+	"github.com/viant/datly/repository/shape/typectx"
+)
+
+// TestFromIRToIR_RoundTripPreservesRoot verifies FromIR followed by ToIR
+// leaves the generic Root tree byte-equal (deep clone, no mutation).
+func TestFromIRToIR_RoundTripPreservesRoot(t *testing.T) {
+	source := &ir.Document{Root: map[string]any{
+		"Routes": []any{
+			map[string]any{
+				"Name":   "Route",
+				"URI":    "/x",
+				"Method": "GET",
+				"View": map[string]any{
+					"Ref": "rootView",
+				},
+			},
+		},
+		"Resource": map[string]any{
+			"Views": []any{
+				map[string]any{
+					"Name":  "rootView",
+					"Table": "T",
+					"Connector": map[string]any{
+						"Ref": "main",
+					},
+					"Template": map[string]any{
+						"Source": "SELECT * FROM T",
+					},
+				},
+			},
+		},
+	}}
+	shapeDoc, err := FromIR(source)
+	if err != nil {
+		t.Fatalf("FromIR failed: %v", err)
+	}
+	if shapeDoc == nil || len(shapeDoc.Routes) != 1 || shapeDoc.Resource == nil || len(shapeDoc.Resource.Views) != 1 {
+		t.Fatalf("unexpected shape projection: %+v", shapeDoc)
+	}
+	target, err := ToIR(shapeDoc)
+	if err != nil {
+		t.Fatalf("ToIR failed: %v", err)
+	}
+	if !reflect.DeepEqual(source.Root, target.Root) {
+		t.Fatalf("round-trip mismatch")
+	}
+}
+
+// TestToIR_FromIR_TypeContextRoundTrip verifies the TypeContext section
+// survives serialization into the IR root and back.
+func TestToIR_FromIR_TypeContextRoundTrip(t *testing.T) {
+	doc := &Document{
+		Root: map[string]any{
+			"Routes":   []any{},
+			"Resource": map[string]any{},
+		},
+		TypeContext: &typectx.Context{
+			DefaultPackage: "mdp/performance",
+			Imports: []typectx.Import{
+				{Alias: "perf", Package: "github.com/acme/mdp/performance"},
+			},
+		},
+	}
+	irDoc, err := ToIR(doc)
+	if err != nil {
+		t.Fatalf("ToIR failed: %v", err)
+	}
+	shapeDoc, err := FromIR(irDoc)
+	if err != nil {
+		t.Fatalf("FromIR failed: %v", err)
+	}
+	if shapeDoc.TypeContext == nil {
+		t.Fatalf("expected type context")
+	}
+	if shapeDoc.TypeContext.DefaultPackage != "mdp/performance" {
+		t.Fatalf("unexpected default package: %s", shapeDoc.TypeContext.DefaultPackage)
+	}
+	if len(shapeDoc.TypeContext.Imports) != 1 {
+		t.Fatalf("unexpected imports count: %d", len(shapeDoc.TypeContext.Imports))
+	}
+}
+
+// TestToIR_FromIR_TypeResolutionsRoundTrip verifies TypeResolutions
+// (including provenance) survive serialization into the IR root and back.
+func TestToIR_FromIR_TypeResolutionsRoundTrip(t *testing.T) {
+	doc := &Document{
+		Root: map[string]any{
+			"Routes":   []any{},
+			"Resource": map[string]any{},
+		},
+		TypeResolutions: []typectx.Resolution{
+			{
+				Expression:  "Order",
+				Target:      "main.ID",
+				ResolvedKey: "github.com/acme/mdp/performance.Order",
+				MatchKind:   "default_package",
+				Provenance: typectx.Provenance{
+					Package: "github.com/acme/mdp/performance",
+					File:    "/repo/mdp/performance/order.go",
+					Kind:    "resource_type",
+				},
+			},
+		},
+	}
+	irDoc, err := ToIR(doc)
+	if err != nil {
+		t.Fatalf("ToIR failed: %v", err)
+	}
+	shapeDoc, err := FromIR(irDoc)
+	if err != nil {
+		t.Fatalf("FromIR failed: %v", err)
+	}
+	if len(shapeDoc.TypeResolutions) != 1 {
+		t.Fatalf("unexpected type resolutions count: %d", len(shapeDoc.TypeResolutions))
+	}
+	got := shapeDoc.TypeResolutions[0]
+	if got.ResolvedKey != "github.com/acme/mdp/performance.Order" || got.Provenance.Kind != "resource_type" {
+		t.Fatalf("unexpected resolution: %+v", got)
+	}
+}
diff --git a/repository/shape/dql/shape/model.go b/repository/shape/dql/shape/model.go
new file mode 100644
index 000000000..e975cfca3
--- /dev/null
+++ b/repository/shape/dql/shape/model.go
@@ -0,0 +1,114 @@
+package shape
+
+import (
+	"fmt"
+
+	"github.com/viant/datly/repository/shape/typectx"
+	"github.com/viant/sqlparser/query"
+)
+
+// Severity represents diagnostic severity level.
+type Severity string
+
+// Recognized values for Diagnostic.Severity.
+const (
+	SeverityError   Severity = "error"
+	SeverityWarning Severity = "warning"
+	SeverityInfo    Severity = "info"
+)
+
+// Position identifies a byte offset and human-readable line/character location.
+type Position struct {
+	Offset int // byte offset into the DQL source
+	Line   int // line number of the location
+	Char   int // character position within the line
+}
+
+// Span captures the location range for one diagnostic.
+type Span struct {
+	Start Position
+	End   Position
+}
+
+// Diagnostic represents one compile/parse issue with precise location.
+type Diagnostic struct {
+	Code     string // machine-readable diagnostic code; may be empty (see Error)
+	Severity Severity
+	Message  string
+	Hint     string // optional remediation hint
+	Span     Span
+}
+
+// Directives captures special #set(...) directives parsed from DQL.
+type Directives struct {
+	Meta              string
+	DefaultConnector  string
+	Cache             *CacheDirective
+	MCP               *MCPDirective
+	Route             *RouteDirective
+	JSONMarshalType   string
+	JSONUnmarshalType string
+	XMLUnmarshalType  string
+	Format            string
+	DateFormat        string
+	CaseFormat        string
+}
+
+// CacheDirective holds the cache settings of a Directives block.
+type CacheDirective struct {
+	Enabled bool
+	TTL     string
+}
+
+// MCPDirective holds the MCP metadata of a Directives block.
+type MCPDirective struct {
+	Name            string
+	Description     string
+	DescriptionPath string
+}
+
+// RouteDirective holds the route settings of a Directives block.
+type RouteDirective struct {
+	URI     string
+	Methods []string
+}
+
+// Route is one route entry projected from the IR document.
+type Route struct {
+	Name        string
+	URI         string
+	Method      string
+	ViewRef     string // name referenced via the route's View.Ref
+	Description string
+}
+
+// Resource groups the views projected from the IR resource section.
+type Resource struct {
+	Views []*View
+}
+
+// View is one view entry projected from the IR resource section.
+type View struct {
+	Name         string
+	Table        string
+	Module       string
+	ConnectorRef string // name referenced via the view's Connector.Ref
+}
+
+// Document represents parsed DQL model used by shape compiler and xgen.
+type Document struct {
+	Raw             string
+	SQL             string
+	Query           *query.Select
+	TypeContext     *typectx.Context
+	Directives      *Directives
+	Routes          []*Route
+	Resource        *Resource
+	Root            map[string]any // generic IR tree the projections were built from
+	TypeResolutions []typectx.Resolution
+	Diagnostics     []*Diagnostic
+}
+
+// Error returns a compact human-readable diagnostic string.
+func (d *Diagnostic) Error() string {
+	if d == nil {
+		return ""
+	}
+	// Without a code only the message is reported; otherwise the code
+	// prefixes the message.
+	if d.Code == "" {
+		return fmt.Sprintf("%s at line %d, char %d", d.Message, d.Span.Start.Line, d.Span.Start.Char)
+	}
+	return fmt.Sprintf("%s: %s at line %d, char %d", d.Code, d.Message, d.Span.Start.Line, d.Span.Start.Char)
+}
diff --git a/repository/shape/dql/statement/parity_test.go b/repository/shape/dql/statement/parity_test.go
new file mode 100644
index 000000000..f8b734114
--- /dev/null
+++ b/repository/shape/dql/statement/parity_test.go
@@ -0,0 +1,33 @@
+package statement
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	legacy "github.com/viant/datly/internal/translator/parser"
+)
+
+// TestStatements_ParityWithLegacyScanner verifies the new statement parser
+// produces the same statement count, IsExec flags and Start/End offsets as
+// the legacy scanner for representative read/exec/service/nop inputs.
+func TestStatements_ParityWithLegacyScanner(t *testing.T) {
+	testCases := []struct {
+		name string
+		sql  string
+	}{
+		{name: "read", sql: "SELECT id FROM orders"},
+		{name: "exec update", sql: "UPDATE orders SET id = 1"},
+		{name: "mixed", sql: "SELECT id FROM orders\nUPDATE orders SET id = 1"},
+		{name: "service insert", sql: `$sql.Insert("orders", $rec)`},
+		{name: "nop", sql: `$Nop($x)`},
+	}
+	for _, testCase := range testCases {
+		t.Run(testCase.name, func(t *testing.T) {
+			current := New(testCase.sql)
+			old := legacy.NewStatements(testCase.sql)
+			assert.Equal(t, len(old), len(current))
+			for i := 0; i < len(old) && i < len(current); i++ {
+				assert.Equal(t, old[i].IsExec, current[i].IsExec)
+				assert.Equal(t, old[i].Start, current[i].Start)
+				assert.Equal(t, old[i].End, current[i].End)
+			}
+		})
+	}
+}
diff --git a/repository/shape/dql/statement/parser.go b/repository/shape/dql/statement/parser.go
new file mode 100644
index 000000000..eaf7a9bdd
--- /dev/null
+++ b/repository/shape/dql/statement/parser.go
@@ -0,0 +1,201 @@
+package statement
+
+import (
+	"strings"
+
+	"github.com/viant/datly/view/keywords"
+	"github.com/viant/parsly"
+	"github.com/viant/parsly/matcher"
+	aexpr "github.com/viant/velty/ast/expr"
+	veltyparser "github.com/viant/velty/parser"
+)
+
+const (
stmtWhitespaceToken = iota + stmtExprGroupToken + stmtExecToken + stmtReadToken + stmtExprToken + stmtExprEndToken + stmtAnyToken +) + +var ( + stmtWhitespaceMatcher = parsly.NewToken(stmtWhitespaceToken, "Whitespace", matcher.NewWhiteSpace()) + stmtExprGroupMatcher = parsly.NewToken(stmtExprGroupToken, "( ... )", matcher.NewBlock('(', ')', '\\')) + stmtExecMatcher = parsly.NewToken(stmtExecToken, "Exec", matcher.NewFragmentsFold([]byte("insert"), []byte("update"), []byte("delete"), []byte("call"), []byte("begin"))) + stmtReadMatcher = parsly.NewToken(stmtReadToken, "Read", matcher.NewFragmentsFold([]byte("select"))) + stmtExprMatcher = parsly.NewToken(stmtExprToken, "Expression", matcher.NewFragments([]byte("#set"), []byte("#foreach"), []byte("#if"))) + stmtExprEndMatcher = parsly.NewToken(stmtExprEndToken, "#end", matcher.NewFragmentsFold([]byte("#end"))) + stmtAnyMatcher = parsly.NewToken(stmtAnyToken, "Any", &anyMatcher{}) +) + +type anyMatcher struct{} + +func (a *anyMatcher) Match(cursor *parsly.Cursor) int { + if cursor.Pos < cursor.InputSize { + return 1 + } + return 0 +} + +func parseStatements(sqlText string) Statements { + cursor := parsly.NewCursor("", []byte(sqlText), 0) + var ( + result Statements + current *Statement + ) + for cursor.Pos < cursor.InputSize { + if consumeCommentOrQuoted(sqlText, cursor) { + continue + } + if cursor.Input[cursor.Pos] == '(' { + if block := cursor.MatchOne(stmtExprGroupMatcher); block.Code == stmtExprGroupToken { + continue + } + } + _ = cursor.MatchOne(stmtWhitespaceMatcher) + beforeMatch := cursor.Pos + matched := cursor.MatchAfterOptional(stmtWhitespaceMatcher, stmtExprMatcher, stmtExprEndMatcher, stmtExecMatcher, stmtReadMatcher, stmtAnyMatcher) + switch matched.Code { + case stmtExprToken: + _ = cursor.MatchAfterOptional(stmtWhitespaceMatcher, stmtExprGroupMatcher) + case stmtExecToken, stmtReadToken: + isExec := matched.Code == stmtExecToken + kind := KindRead + if isExec { + kind = KindExec + } + if 
nextWhitespace(cursor) { + if current != nil { + current.End = beforeMatch + } + current = &Statement{ + Start: beforeMatch, + End: -1, + Kind: kind, + IsExec: isExec, + } + result = append(result, current) + } + case stmtAnyToken: + kind, method, ok := getStmtSelector(matched, cursor) + if ok { + if current != nil { + current.End = beforeMatch + } + current = &Statement{ + Start: beforeMatch, + End: -1, + IsExec: true, + Kind: kind, + SelectorMethod: method, + } + result = append(result, current) + } + if !ok { + advanceToWhitespace(cursor) + } + _ = nextWhitespace(cursor) + } + } + if current != nil { + current.End = len(sqlText) + } + if len(result) == 0 { + kind, isExec, selector := inferDefaultKind(sqlText) + result = append(result, &Statement{ + Start: 0, + End: len(sqlText), + Kind: kind, + IsExec: isExec, + SelectorMethod: selector, + }) + } + return result +} + +func consumeCommentOrQuoted(sqlText string, cursor *parsly.Cursor) bool { + if cursor.Pos >= cursor.InputSize { + return false + } + if startsWithAt(sqlText, cursor.Pos, "--") { + cursor.Pos += 2 + for cursor.Pos < cursor.InputSize && sqlText[cursor.Pos] != '\n' { + cursor.Pos++ + } + return true + } + if startsWithAt(sqlText, cursor.Pos, "/*") { + cursor.Pos += 2 + for cursor.Pos+1 < cursor.InputSize { + if sqlText[cursor.Pos] == '*' && sqlText[cursor.Pos+1] == '/' { + cursor.Pos += 2 + return true + } + cursor.Pos++ + } + cursor.Pos = cursor.InputSize + return true + } + switch sqlText[cursor.Pos] { + case '\'', '"', '`': + quote := sqlText[cursor.Pos] + cursor.Pos++ + for cursor.Pos < cursor.InputSize { + ch := sqlText[cursor.Pos] + cursor.Pos++ + if ch == quote && (cursor.Pos < 2 || sqlText[cursor.Pos-2] != '\\') { + break + } + } + return true + } + return false +} + +func startsWithAt(text string, offset int, candidate string) bool { + if offset < 0 || offset+len(candidate) > len(text) { + return false + } + return text[offset:offset+len(candidate)] == candidate +} + +func 
getStmtSelector(matched *parsly.TokenMatch, cursor *parsly.Cursor) (string, string, bool) { + if matched.Text(cursor) != "$" { + return "", "", false + } + selector, err := veltyparser.MatchSelector(cursor) + if err != nil || selector == nil { + return "", "", false + } + if strings.EqualFold(selector.ID, "Nop") { + return KindExec, "Nop", true + } + if !strings.EqualFold(selector.ID, keywords.KeySQL) || selector.X == nil { + return "", "", false + } + aSelector, ok := selector.X.(*aexpr.Select) + if !ok { + return "", "", false + } + if aSelector.ID != "Insert" && aSelector.ID != "Update" { + return "", "", false + } + return KindService, aSelector.ID, true +} + +func nextWhitespace(cursor *parsly.Cursor) bool { + before := cursor.Pos + _ = cursor.MatchOne(stmtWhitespaceMatcher) + return before != cursor.Pos +} + +func advanceToWhitespace(cursor *parsly.Cursor) { + for cursor.Pos < cursor.InputSize { + if matcher.IsWhiteSpace(cursor.Input[cursor.Pos]) { + return + } + cursor.Pos++ + } +} diff --git a/repository/shape/dql/statement/statement.go b/repository/shape/dql/statement/statement.go new file mode 100644 index 000000000..58d8836b9 --- /dev/null +++ b/repository/shape/dql/statement/statement.go @@ -0,0 +1,137 @@ +package statement + +import ( + "strings" + + "github.com/viant/sqlparser" +) + +const ( + KindRead = "read" + KindExec = "exec" + KindService = "service" +) + +type Statement struct { + Start int + End int + Kind string + IsExec bool + SelectorMethod string + Table string +} + +type Statements []*Statement + +func (s Statements) IsExec() bool { + if len(s) == 0 { + return true + } + for _, item := range s { + if item != nil && item.IsExec { + return true + } + } + return false +} + +func (s Statements) DMLTables(rawSQL string) []string { + var ( + tables = map[string]bool{} + result []string + ) + for _, statement := range s { + if statement == nil || !statement.IsExec { + continue + } + sqlText := slice(rawSQL, statement.Start, statement.End) + if 
statement.Kind == KindService { + if table := firstQuotedArgument(sqlText); table != "" { + statement.Table = table + if !tables[table] { + result = append(result, table) + } + tables[table] = true + continue + } + } + lower := strings.ToLower(sqlText) + switch { + case strings.Contains(lower, "insert"): + if parsed, _ := sqlparser.ParseInsert(sqlText); parsed != nil && parsed.Target.X != nil { + statement.Table = strings.TrimSpace(sqlparser.Stringify(parsed.Target.X)) + } + case strings.Contains(lower, "update"): + if parsed, _ := sqlparser.ParseUpdate(sqlText); parsed != nil && parsed.Target.X != nil { + statement.Table = strings.TrimSpace(sqlparser.Stringify(parsed.Target.X)) + } + case strings.Contains(lower, "delete"): + if parsed, _ := sqlparser.ParseDelete(sqlText); parsed != nil && parsed.Target.X != nil { + statement.Table = strings.TrimSpace(sqlparser.Stringify(parsed.Target.X)) + } + } + if statement.Table == "" { + continue + } + if !tables[statement.Table] { + result = append(result, statement.Table) + } + tables[statement.Table] = true + } + return result +} + +func New(sqlText string) Statements { + if strings.TrimSpace(sqlText) == "" { + return Statements{&Statement{Start: 0, End: 0}} + } + return parseStatements(sqlText) +} + +func slice(input string, start, end int) string { + if start < 0 { + start = 0 + } + if end < start { + end = start + } + if end > len(input) { + end = len(input) + } + return input[start:end] +} + +func firstQuotedArgument(sqlText string) string { + index := strings.Index(sqlText, `"`) + if index == -1 { + return "" + } + tail := sqlText[index+1:] + end := strings.Index(tail, `"`) + if end == -1 { + return "" + } + return strings.TrimSpace(tail[:end]) +} + +func inferDefaultKind(sqlText string) (string, bool, string) { + trimmed := strings.TrimSpace(strings.ToLower(sqlText)) + switch { + case strings.HasPrefix(trimmed, "select"): + return KindRead, false, "" + case strings.HasPrefix(trimmed, "insert"), + 
strings.HasPrefix(trimmed, "update"), + strings.HasPrefix(trimmed, "delete"), + strings.HasPrefix(trimmed, "call"), + strings.HasPrefix(trimmed, "begin"): + return KindExec, true, "" + case strings.HasPrefix(trimmed, "$sql.insert"): + return KindService, true, "Insert" + case strings.HasPrefix(trimmed, "$sql.update"): + return KindService, true, "Update" + case strings.HasPrefix(trimmed, "$nop("): + return KindExec, true, "Nop" + default: + return "", false, "" + } +} diff --git a/repository/shape/dql/statement/statement_test.go b/repository/shape/dql/statement/statement_test.go new file mode 100644 index 000000000..45d7d14a3 --- /dev/null +++ b/repository/shape/dql/statement/statement_test.go @@ -0,0 +1,77 @@ +package statement + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew_ReadStatement(t *testing.T) { + stmts := New("SELECT id FROM orders") + require.Len(t, stmts, 1) + assert.False(t, stmts[0].IsExec) + assert.Equal(t, KindRead, stmts[0].Kind) +} + +func TestNew_ExecStatements(t *testing.T) { + sqlText := "INSERT INTO orders(id) VALUES (1)\nUPDATE orders SET name = 'x' WHERE id = 1" + stmts := New(sqlText) + require.Len(t, stmts, 2) + assert.True(t, stmts[0].IsExec) + assert.True(t, stmts[1].IsExec) + assert.Equal(t, KindExec, stmts[0].Kind) +} + +func TestNew_ServiceExec(t *testing.T) { + stmts := New(`$sql.Insert("ORDERS", $rec)`) + require.Len(t, stmts, 1) + assert.True(t, stmts[0].IsExec) + assert.Equal(t, KindService, stmts[0].Kind) + assert.Equal(t, "Insert", stmts[0].SelectorMethod) +} + +func TestStatements_DMLTables(t *testing.T) { + stmts := New(`INSERT INTO orders(id) VALUES (1) +UPDATE orders SET id = 2 +DELETE FROM items WHERE id = 1 +$sql.Insert("ORDERS_AUDIT", $rec)`) + tables := stmts.DMLTables(`INSERT INTO orders(id) VALUES (1) +UPDATE orders SET id = 2 +DELETE FROM items WHERE id = 1 +$sql.Insert("ORDERS_AUDIT", $rec)`) + assert.Equal(t, []string{"orders", "items", 
"ORDERS_AUDIT"}, tables) +} + +func TestNew_IgnoreKeywordsInCommentsAndStrings(t *testing.T) { + sqlText := "-- insert into x\nSELECT 'update x' as txt FROM orders" + stmts := New(sqlText) + require.Len(t, stmts, 1) + assert.Equal(t, KindRead, stmts[0].Kind) + assert.False(t, stmts[0].IsExec) +} + +func TestNew_DefaultUnknownIsNotExec(t *testing.T) { + stmts := New("$foo.Bar($baz)") + require.Len(t, stmts, 1) + assert.Equal(t, "", stmts[0].Kind) + assert.False(t, stmts[0].IsExec) +} + +func TestNew_DefaultNopIsExec(t *testing.T) { + stmts := New("$Nop($Unsafe.Id)") + require.Len(t, stmts, 1) + assert.Equal(t, KindExec, stmts[0].Kind) + assert.True(t, stmts[0].IsExec) + assert.Equal(t, "Nop", stmts[0].SelectorMethod) +} + +func TestNew_NestedSubquerySelect_IsSingleReadStatement(t *testing.T) { + sqlText := `SELECT session.* +FROM (SELECT * FROM session WHERE user_id = $criteria.AppendBinding($Unsafe.Jwt.UserID)) session +JOIN (SELECT * FROM session/attributes) attribute ON attribute.user_id = session.user_id` + stmts := New(sqlText) + require.Len(t, stmts, 1) + assert.Equal(t, KindRead, stmts[0].Kind) + assert.False(t, stmts[0].IsExec) +} diff --git a/repository/shape/dql_engine_test.go b/repository/shape/dql_engine_test.go new file mode 100644 index 000000000..a7a40fff5 --- /dev/null +++ b/repository/shape/dql_engine_test.go @@ -0,0 +1,66 @@ +package shape_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + shape "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" +) + +func TestEngine_LoadDQLViews(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/reports/orders"), + ) + artifacts, err := engine.LoadDQLViews(context.Background(), "SELECT id FROM ORDERS t") + require.NoError(t, err) + 
require.NotNil(t, artifacts) + require.Len(t, artifacts.Views, 1) + assert.Equal(t, "t", artifacts.Views[0].Name) +} + +func TestEngine_LoadDQLComponent(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/reports/orders"), + ) + artifact, err := engine.LoadDQLComponent(context.Background(), "SELECT id FROM ORDERS t") + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Component) + + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + assert.Equal(t, "/v1/api/reports/orders", component.Name) + assert.Equal(t, "t", component.RootView) +} + +func TestEngine_LoadDQLComponent_DeclarationMetadata(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/reports/orders"), + ) + dql := ` +#set($_ = $limit(view/limit).WithPredicate('ByID','id = ?', 1).QuerySelector('items') /* SELECT id FROM ORDERS o */) +SELECT id FROM ORDERS t` + artifact, err := engine.LoadDQLComponent(context.Background(), dql) + require.NoError(t, err) + require.NotNil(t, artifact) + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component.Declarations) + require.NotNil(t, component.QuerySelectors) + require.NotNil(t, component.Predicates) + assert.Equal(t, []string{"o"}, component.QuerySelectors["items"]) + require.NotNil(t, component.Declarations["o"]) + assert.Equal(t, "items", component.Declarations["o"].QuerySelector) + require.NotEmpty(t, component.Predicates["o"]) + assert.Equal(t, "ByID", component.Predicates["o"][0].Name) +} diff --git a/repository/shape/engine_compile_options_test.go b/repository/shape/engine_compile_options_test.go new file mode 100644 index 000000000..28de1a8ff --- /dev/null +++ b/repository/shape/engine_compile_options_test.go @@ -0,0 +1,77 @@ +package shape + +import ( + "context" + "testing" + + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type captureCompiler struct { + last CompileOptions +} + +func (c *captureCompiler) Compile(_ context.Context, source *Source, opts ...CompileOption) (*PlanResult, error) { + compiled := &CompileOptions{} + for _, opt := range opts { + if opt != nil { + opt(compiled) + } + } + c.last = *compiled + return &PlanResult{Source: source}, nil +} + +func TestEngine_Compile_UsesLegacyParityDefaults(t *testing.T) { + compiler := &captureCompiler{} + engine := New(WithCompiler(compiler)) + + _, err := engine.compile(context.Background(), &Source{Name: "orders", DQL: "SELECT 1"}) + require.NoError(t, err) + assert.False(t, compiler.last.Strict) + assert.Equal(t, CompileProfileCompat, compiler.last.Profile) + assert.Equal(t, CompileMixedModeExecWins, compiler.last.MixedMode) + assert.Equal(t, CompileUnknownNonReadWarn, compiler.last.UnknownNonReadMode) + assert.Equal(t, CompileColumnDiscoveryAuto, compiler.last.ColumnDiscoveryMode) +} + +func TestEngine_Compile_ForwardsCustomDefaults(t *testing.T) { + compiler := &captureCompiler{} + engine := New( + WithCompiler(compiler), + WithStrict(true), + WithCompileProfileDefault(CompileProfileStrict), + WithMixedModeDefault(CompileMixedModeReadWins), + WithUnknownNonReadModeDefault(CompileUnknownNonReadError), + WithColumnDiscoveryModeDefault(CompileColumnDiscoveryOff), + ) + + _, err := engine.compile(context.Background(), &Source{Name: "orders", DQL: "SELECT 1"}) + require.NoError(t, err) + assert.True(t, compiler.last.Strict) + assert.Equal(t, CompileProfileStrict, compiler.last.Profile) + assert.Equal(t, CompileMixedModeReadWins, compiler.last.MixedMode) + assert.Equal(t, CompileUnknownNonReadError, compiler.last.UnknownNonReadMode) + assert.Equal(t, CompileColumnDiscoveryOff, compiler.last.ColumnDiscoveryMode) +} + +func TestEngine_Compile_LegacyDefaultsOption(t *testing.T) { + compiler := &captureCompiler{} + engine := New( + 
WithCompiler(compiler), + WithStrict(true), + WithCompileProfileDefault(CompileProfileStrict), + WithMixedModeDefault(CompileMixedModeReadWins), + WithUnknownNonReadModeDefault(CompileUnknownNonReadError), + WithLegacyTranslatorDefaults(), + ) + + _, err := engine.compile(context.Background(), &Source{Name: "orders", DQL: "SELECT 1"}) + require.NoError(t, err) + assert.False(t, compiler.last.Strict) + assert.Equal(t, CompileProfileCompat, compiler.last.Profile) + assert.Equal(t, CompileMixedModeExecWins, compiler.last.MixedMode) + assert.Equal(t, CompileUnknownNonReadWarn, compiler.last.UnknownNonReadMode) + assert.Equal(t, CompileColumnDiscoveryAuto, compiler.last.ColumnDiscoveryMode) +} diff --git a/repository/shape/errors.go b/repository/shape/errors.go new file mode 100644 index 000000000..852313b6c --- /dev/null +++ b/repository/shape/errors.go @@ -0,0 +1,12 @@ +package shape + +import "errors" + +var ( + ErrNilSource = errors.New("shape: source was nil") + ErrNilDQL = errors.New("shape: dql was empty") + ErrScannerNotConfigured = errors.New("shape: scanner was not configured") + ErrPlannerNotConfigured = errors.New("shape: planner was not configured") + ErrLoaderNotConfigured = errors.New("shape: loader was not configured") + ErrCompilerNotConfigured = errors.New("shape: compiler was not configured") +) diff --git a/repository/shape/load/columns.go b/repository/shape/load/columns.go new file mode 100644 index 000000000..147a2a5b4 --- /dev/null +++ b/repository/shape/load/columns.go @@ -0,0 +1,92 @@ +package load + +import ( + "reflect" + "strings" + + "github.com/viant/datly/view" +) + +var mapStringInterface = reflect.TypeOf(map[string]interface{}{}) + +// inferColumnsFromType extracts column descriptors from a statically-inferred struct type. +// Returns nil when rType is nil, non-struct, or the untyped map[string]interface{} fallback. 
+func inferColumnsFromType(rType reflect.Type) []*view.Column { + if rType == nil { + return nil + } + // Unwrap slice / pointer wrappers + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + // Skip the untyped fallback used when columns are unknown + if rType == mapStringInterface { + return nil + } + cols := make([]*view.Column, 0, rType.NumField()) + for i := 0; i < rType.NumField(); i++ { + f := rType.Field(i) + if !f.IsExported() { + continue + } + colName := sqlxColumnName(f) + if colName == "" { + colName = f.Name + } + cols = append(cols, &view.Column{ + Name: colName, + DataType: reflectDataType(f.Type), + }) + } + return cols +} + +// sqlxColumnName reads the sqlx struct tag to get the database column name. +func sqlxColumnName(f reflect.StructField) string { + tag := f.Tag.Get("sqlx") + if tag == "" { + return "" + } + for _, part := range strings.Split(tag, ",") { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "name=") { + return strings.TrimPrefix(part, "name=") + } + } + return "" +} + +// reflectDataType maps a Go reflect.Type to a datly column DataType string. 
+func reflectDataType(t reflect.Type) string { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t.Kind() { + case reflect.String: + return "string" + case reflect.Bool: + return "bool" + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: + return "int" + case reflect.Int64: + return "int64" + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint: + return "int" + case reflect.Uint64: + return "int64" + case reflect.Float32: + return "float32" + case reflect.Float64: + return "float64" + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + return "[]byte" + } + return "[]" + reflectDataType(t.Elem()) + default: + return "string" + } +} diff --git a/repository/shape/load/doc.go b/repository/shape/load/doc.go new file mode 100644 index 000000000..1800597cc --- /dev/null +++ b/repository/shape/load/doc.go @@ -0,0 +1,2 @@ +// Package load defines materialization responsibilities for runtime artifacts. +package load diff --git a/repository/shape/load/errors.go b/repository/shape/load/errors.go new file mode 100644 index 000000000..51f15d6ab --- /dev/null +++ b/repository/shape/load/errors.go @@ -0,0 +1,7 @@ +package load + +import "errors" + +var ( + ErrEmptyViewPlan = errors.New("shape load: no views available in plan") +) diff --git a/repository/shape/load/loader.go b/repository/shape/load/loader.go new file mode 100644 index 000000000..03277feff --- /dev/null +++ b/repository/shape/load/loader.go @@ -0,0 +1,469 @@ +package load + +import ( + "context" + "fmt" + "reflect" + "strings" + "time" + + "github.com/viant/datly/repository/shape" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + shapevalidate "github.com/viant/datly/repository/shape/validate" + "github.com/viant/datly/shared" + "github.com/viant/datly/view" + "github.com/viant/datly/view/extension" + "github.com/viant/datly/view/state" +) + +// Loader 
materializes runtime view artifacts from normalized shape plan. +type Loader struct{} + +// New returns shape loader implementation. +func New() *Loader { + return &Loader{} +} + +// LoadViews implements shape.Loader. +func (l *Loader) LoadViews(ctx context.Context, planned *shape.PlanResult, _ ...shape.LoadOption) (*shape.ViewArtifacts, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + pResult, resource, err := l.materialize(planned) + if err != nil { + return nil, err + } + if len(pResult.Views) == 0 { + return nil, ErrEmptyViewPlan + } + return &shape.ViewArtifacts{Resource: resource, Views: resource.Views}, nil +} + +// LoadComponent implements shape.Loader. +func (l *Loader) LoadComponent(ctx context.Context, planned *shape.PlanResult, _ ...shape.LoadOption) (*shape.ComponentArtifact, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + pResult, resource, err := l.materialize(planned) + if err != nil { + return nil, err + } + if len(pResult.Views) == 0 { + return nil, ErrEmptyViewPlan + } + component := buildComponent(planned.Source, pResult) + return &shape.ComponentArtifact{ + Resource: resource, + Component: component, + }, nil +} + +func (l *Loader) materialize(planned *shape.PlanResult) (*plan.Result, *view.Resource, error) { + if planned == nil || planned.Source == nil { + return nil, nil, shape.ErrNilSource + } + pResult, ok := plan.ResultFrom(planned) + if !ok { + return nil, nil, fmt.Errorf("shape load: unsupported plan kind %q", planned.Plan.ShapeSpecKind()) + } + resource := view.EmptyResource() + if pResult.EmbedFS != nil { + resource.SetFSEmbedder(state.NewFSEmbedder(pResult.EmbedFS)) + } + for _, item := range pResult.Views { + aView, err := materializeView(item) + if err != nil { + return nil, nil, err + } + resource.AddViews(aView) + } + if err := shapevalidate.ValidateRelations(resource, resource.Views...); err != nil { + return nil, nil, err + } + // Gap 7: apply global cache TTL directive to root view. 
+ if pResult.Directives != nil && pResult.Directives.Cache != nil { + if ttl := strings.TrimSpace(pResult.Directives.Cache.TTL); ttl != "" { + if dur, err := time.ParseDuration(ttl); err == nil && dur > 0 { + ttlMs := int(dur.Milliseconds()) + if rootPlan := pickRootView(pResult.Views); rootPlan != nil { + for _, rv := range resource.Views { + if rv != nil && rv.Name == rootPlan.Name { + if rv.Cache == nil { + rv.Cache = &view.Cache{} + } + rv.Cache.TimeToLiveMs = ttlMs + break + } + } + } + } + } + } + return pResult, resource, nil +} + +func buildComponent(source *shape.Source, pResult *plan.Result) *Component { + component := &Component{Method: "GET"} + if source != nil { + component.Name = source.Name + component.URI = source.Name + } + applyViewMeta(component, pResult.Views) + applyStateBuckets(component, pResult.States) + component.Input = append(component.Input, synthesizePredicateStates(component.Input, component.Predicates)...) + component.TypeContext = cloneTypeContext(pResult.TypeContext) + component.Directives = cloneDirectives(pResult.Directives) + component.ColumnsDiscovery = pResult.ColumnsDiscovery + return component +} + +// applyViewMeta populates the component with view names, declarations, relations, +// query selectors, predicate maps, and root view from the plan view list. +func applyViewMeta(component *Component, views []*plan.View) { + for _, aView := range views { + if aView == nil { + continue + } + component.Views = append(component.Views, aView.Name) + if aView.Declaration != nil { + indexViewDeclaration(component, aView.Name, aView.Declaration) + } + if len(aView.Relations) > 0 { + component.Relations = append(component.Relations, aView.Relations...) + component.ViewRelations = append(component.ViewRelations, toViewRelations(aView.Relations)...) 
+ } + } + if rootView := pickRootView(views); rootView != nil { + component.RootView = rootView.Name + if component.Name == "" { + component.Name = rootView.Name + } + } +} + +// indexViewDeclaration registers the declaration's query selector and predicates +// on the component index maps, creating them on demand. +func indexViewDeclaration(component *Component, viewName string, decl *plan.ViewDeclaration) { + if component.Declarations == nil { + component.Declarations = map[string]*plan.ViewDeclaration{} + } + component.Declarations[viewName] = decl + if selector := strings.TrimSpace(decl.QuerySelector); selector != "" { + if component.QuerySelectors == nil { + component.QuerySelectors = map[string][]string{} + } + component.QuerySelectors[selector] = append(component.QuerySelectors[selector], viewName) + } + if len(decl.Predicates) > 0 { + if component.Predicates == nil { + component.Predicates = map[string][]*plan.ViewPredicate{} + } + component.Predicates[viewName] = append(component.Predicates[viewName], decl.Predicates...) + } +} + +// applyStateBuckets sorts plan states into the typed buckets on the component +// (Input, Output, Meta, Async, Other) based on the state's location kind. 
+func applyStateBuckets(component *Component, states []*plan.State) { + for _, item := range states { + if item == nil { + continue + } + kind := state.Kind(strings.ToLower(item.KindString())) + inName := item.InName() + if kind == "" && inName == "" { + component.Other = append(component.Other, item) + continue + } + switch kind { + case state.KindQuery, state.KindPath, state.KindHeader, state.KindRequestBody, + state.KindForm, state.KindCookie, state.KindRequest, "": + component.Input = append(component.Input, item) + case state.KindOutput: + component.Output = append(component.Output, item) + case state.KindMeta: + component.Meta = append(component.Meta, item) + case state.KindAsync: + component.Async = append(component.Async, item) + default: + component.Other = append(component.Other, item) + } + } +} + +// synthesizePredicateStates creates query parameters for view-level predicates whose +// source parameter is not already present in the input state list. +func synthesizePredicateStates(input []*plan.State, predicates map[string][]*plan.ViewPredicate) []*plan.State { + if len(predicates) == 0 { + return nil + } + declared := make(map[string]bool, len(input)) + for _, s := range input { + if s != nil { + declared[strings.ToLower(strings.TrimPrefix(strings.TrimSpace(s.Name), "$"))] = true + } + } + var result []*plan.State + for _, viewPredicates := range predicates { + for _, vp := range viewPredicates { + if vp == nil { + continue + } + src := strings.TrimPrefix(strings.TrimSpace(vp.Source), "$") + if src == "" || declared[strings.ToLower(src)] { + continue + } + result = append(result, &plan.State{ + Parameter: state.Parameter{ + Name: src, + In: state.NewQueryLocation(src), + Schema: &state.Schema{DataType: "string"}, + Predicates: []*extension.PredicateConfig{ + { + Name: vp.Name, + Ensure: vp.Ensure, + Args: append([]string{}, vp.Arguments...), + }, + }, + }, + }) + declared[strings.ToLower(src)] = true + } + } + return result +} + +func 
cloneTypeContext(input *typectx.Context) *typectx.Context { + if input == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + PackageDir: strings.TrimSpace(input.PackageDir), + PackageName: strings.TrimSpace(input.PackageName), + PackagePath: strings.TrimSpace(input.PackagePath), + } + for _, item := range input.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + ret.Imports = append(ret.Imports, typectx.Import{ + Alias: strings.TrimSpace(item.Alias), + Package: pkg, + }) + } + if ret.DefaultPackage == "" && + len(ret.Imports) == 0 && + ret.PackageDir == "" && + ret.PackageName == "" && + ret.PackagePath == "" { + return nil + } + return ret +} + +func cloneDirectives(input *dqlshape.Directives) *dqlshape.Directives { + if input == nil { + return nil + } + ret := &dqlshape.Directives{ + Meta: strings.TrimSpace(input.Meta), + DefaultConnector: strings.TrimSpace(input.DefaultConnector), + } + if input.Cache != nil { + ret.Cache = &dqlshape.CacheDirective{ + Enabled: input.Cache.Enabled, + TTL: strings.TrimSpace(input.Cache.TTL), + } + } + if input.MCP != nil { + ret.MCP = &dqlshape.MCPDirective{ + Name: strings.TrimSpace(input.MCP.Name), + Description: strings.TrimSpace(input.MCP.Description), + DescriptionPath: strings.TrimSpace(input.MCP.DescriptionPath), + } + } + if ret.Meta == "" && ret.DefaultConnector == "" && ret.Cache == nil && ret.MCP == nil { + return nil + } + return ret +} + +func pickRootView(views []*plan.View) *plan.View { + var selected *plan.View + minDepth := -1 + for _, candidate := range views { + if candidate == nil || candidate.Path == "" { + continue + } + depth := strings.Count(candidate.Path, ".") + if minDepth == -1 || depth < minDepth { + minDepth = depth + selected = candidate + } + } + if selected != nil { + return selected + } + for _, candidate := range views { + if candidate != nil { + return candidate + } + } + return nil +} + +func 
materializeView(item *plan.View) (*view.View, error) { + if item == nil { + return nil, fmt.Errorf("shape load: nil view plan item") + } + + schemaType := bestSchemaType(item) + if schemaType == nil { + return nil, fmt.Errorf("shape load: missing schema type for view %q", item.Name) + } + + schema := newSchema(schemaType, item.Cardinality) + mode := view.ModeQuery + switch strings.TrimSpace(item.Mode) { + case string(view.ModeExec): + mode = view.ModeExec + case string(view.ModeHandler): + mode = view.ModeHandler + case string(view.ModeQuery): + mode = view.ModeQuery + } + opts := []view.Option{view.WithSchema(schema), view.WithMode(mode)} + + if item.Connector != "" { + opts = append(opts, view.WithConnectorRef(item.Connector)) + } + if item.SQL != "" || item.SQLURI != "" { + tmpl := view.NewTemplate(item.SQL) + tmpl.SourceURL = item.SQLURI + if strings.TrimSpace(item.Summary) != "" { + tmpl.Summary = &view.TemplateSummary{ + Name: "Summary", + Source: item.Summary, + Kind: view.MetaKindRecord, + } + } + opts = append(opts, view.WithTemplate(tmpl)) + } + if item.CacheRef != "" { + opts = append(opts, view.WithCache(&view.Cache{Reference: shared.Reference{Ref: item.CacheRef}})) + } + if item.Partitioner != "" { + opts = append(opts, view.WithPartitioned(&view.Partitioned{ + DataType: item.Partitioner, + Concurrency: item.PartitionedConcurrency, + })) + } + + aView, err := view.New(item.Name, item.Table, opts...) + if err != nil { + return nil, err + } + aView.Ref = item.Ref + aView.Module = item.Module + aView.AllowNulls = item.AllowNulls + // Gap 6: forward view-level tag from declaration. 
+ if item.Declaration != nil && strings.TrimSpace(item.Declaration.Tag) != "" { + aView.Tag = strings.TrimSpace(item.Declaration.Tag) + } + if strings.TrimSpace(item.SelectorNamespace) != "" || item.SelectorNoLimit != nil { + if aView.Selector == nil { + aView.Selector = &view.Config{} + } + if strings.TrimSpace(item.SelectorNamespace) != "" { + aView.Selector.Namespace = strings.TrimSpace(item.SelectorNamespace) + } + if item.SelectorNoLimit != nil { + aView.Selector.NoLimit = *item.SelectorNoLimit + } + } + if aView.Schema != nil && strings.TrimSpace(item.SchemaType) != "" { + if aView.Schema.DataType == "" { + aView.Schema.DataType = strings.TrimSpace(item.SchemaType) + } + if aView.Schema.Name == "" { + aView.Schema.Name = strings.Trim(strings.TrimSpace(item.SchemaType), "*") + } + } + // Populate columns from statically-inferred struct type so that xgen can + // generate accurate Go struct definitions during bootstrap. Only applied when + // the view has no columns yet (avoids overwriting explicit column config). 
+ if len(aView.Columns) == 0 { + if cols := inferColumnsFromType(item.ElementType); len(cols) > 0 { + aView.Columns = cols + } + } + return aView, nil +} + +func bestSchemaType(item *plan.View) reflect.Type { + if item.FieldType != nil { + return item.FieldType + } + if item.ElementType != nil { + return item.ElementType + } + return nil +} + +func toViewRelations(input []*plan.Relation) []*view.Relation { + if len(input) == 0 { + return nil + } + result := make([]*view.Relation, 0, len(input)) + for _, item := range input { + if item == nil { + continue + } + relation := &view.Relation{ + Name: item.Name, + Holder: item.Holder, + On: toViewLinks(item.On, true), + Of: view.NewReferenceView( + toViewLinks(item.On, false), + view.NewView(item.Ref, item.Table), + ), + } + result = append(result, relation) + } + return result +} + +func toViewLinks(input []*plan.RelationLink, parent bool) view.Links { + if len(input) == 0 { + return nil + } + result := make(view.Links, 0, len(input)) + for _, item := range input { + if item == nil { + continue + } + link := &view.Link{} + if parent { + link.Field = item.ParentField + link.Namespace = item.ParentNamespace + link.Column = item.ParentColumn + } else { + link.Field = item.RefField + link.Namespace = item.RefNamespace + link.Column = item.RefColumn + } + result = append(result, link) + } + return result +} + +func newSchema(rType reflect.Type, cardinality string) *state.Schema { + if cardinality == "many" && rType.Kind() != reflect.Slice { + return state.NewSchema(rType, state.WithMany()) + } + return state.NewSchema(rType) +} diff --git a/repository/shape/load/loader_test.go b/repository/shape/load/loader_test.go new file mode 100644 index 000000000..20117b4e9 --- /dev/null +++ b/repository/shape/load/loader_test.go @@ -0,0 +1,244 @@ +package load + +import ( + "context" + "embed" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + 
"github.com/viant/datly/repository/shape" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/scan" + "github.com/viant/datly/repository/shape/typectx" +) + +//go:embed testdata/*.sql +var testFS embed.FS + +type embeddedFS struct{} + +func (embeddedFS) EmbedFS() *embed.FS { + return &testFS +} + +type reportRow struct { + ID int + Name string +} + +type reportSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT,connector=dev,cache=c1" sql:"uri=testdata/report.sql"` + ID int `parameter:"id,kind=query,in=id"` + Status any `parameter:"status,kind=output,in=status"` + Job any `parameter:"job,kind=async,in=job"` + Meta any `parameter:"meta,kind=meta,in=view.name"` +} + +func TestLoader_LoadViews(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifacts, err := loader.LoadViews(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.NotNil(t, artifacts.Resource) + require.Len(t, artifacts.Views, 1) + + aView := artifacts.Views[0] + assert.Equal(t, "rows", aView.Name) + assert.Equal(t, "REPORT", aView.Table) + require.NotNil(t, aView.Schema) + assert.Equal(t, "Many", string(aView.Schema.Cardinality)) + require.NotNil(t, aView.Template) + assert.Equal(t, "testdata/report.sql", aView.Template.SourceURL) + assert.Contains(t, aView.Template.Source, "SELECT ID, NAME FROM REPORT") + require.NotNil(t, aView.Connector) + assert.Equal(t, "dev", aView.Connector.Ref) + require.NotNil(t, aView.Cache) + assert.Equal(t, "c1", aView.Cache.Ref) + require.NotNil(t, artifacts.Resource.EmbedFS()) +} + +// stubPlanSpec is a non-plan-Result implementation of shape.PlanSpec used to +// 
verify that LoadViews() returns an error when given an unexpected plan type. +type stubPlanSpec struct{} + +func (s *stubPlanSpec) ShapeSpecKind() string { return "stub" } + +func TestLoader_LoadViews_InvalidPlanType(t *testing.T) { + loader := New() + _, err := loader.LoadViews(context.Background(), &shape.PlanResult{Source: &shape.Source{Name: "x"}, Plan: &stubPlanSpec{}}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported plan kind") +} + +func TestLoader_LoadViews_Metadata(t *testing.T) { + noLimit := true + allowNulls := true + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "items", + Table: "ITEMS", + Module: "platform/items", + AllowNulls: &allowNulls, + SelectorNamespace: "it", + SelectorNoLimit: &noLimit, + SchemaType: "*ItemView", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM ITEMS", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + loader := New() + artifacts, err := loader.LoadViews(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.Len(t, artifacts.Views, 1) + actual := artifacts.Views[0] + assert.Equal(t, "platform/items", actual.Module) + require.NotNil(t, actual.AllowNulls) + assert.True(t, *actual.AllowNulls) + require.NotNil(t, actual.Selector) + assert.Equal(t, "it", actual.Selector.Namespace) + assert.True(t, actual.Selector.NoLimit) + require.NotNil(t, actual.Schema) + assert.Equal(t, "*ItemView", actual.Schema.DataType) +} + +func TestLoader_LoadComponent(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Name: "/v1/api/report", Struct: &reportSource{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, 
err) + actualPlan, ok := plan.ResultFrom(planned) + require.True(t, ok) + actualPlan.ColumnsDiscovery = true + actualPlan.TypeContext = &typectx.Context{ + DefaultPackage: "mdp/performance", + Imports: []typectx.Import{ + {Alias: "perf", Package: "github.com/acme/mdp/performance"}, + }, + } + actualPlan.Directives = &dqlshape.Directives{ + Meta: "docs/report.md", + DefaultConnector: "analytics", + Cache: &dqlshape.CacheDirective{ + Enabled: true, + TTL: "5m", + }, + MCP: &dqlshape.MCPDirective{ + Name: "report.list", + Description: "List report rows", + DescriptionPath: "docs/mcp/report.md", + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + require.NotNil(t, artifact.Component) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + assert.Equal(t, "/v1/api/report", component.Name) + assert.Equal(t, "/v1/api/report", component.URI) + assert.Equal(t, "GET", component.Method) + assert.Equal(t, "rows", component.RootView) + assert.Equal(t, []string{"rows"}, component.Views) + assert.Len(t, component.Input, 1) + assert.Len(t, component.Output, 1) + assert.Len(t, component.Async, 1) + assert.Len(t, component.Meta, 1) + require.NotNil(t, component.TypeContext) + assert.Equal(t, "mdp/performance", component.TypeContext.DefaultPackage) + require.Len(t, component.TypeContext.Imports, 1) + assert.Equal(t, "perf", component.TypeContext.Imports[0].Alias) + require.NotNil(t, component.Directives) + assert.Equal(t, "docs/report.md", component.Directives.Meta) + assert.Equal(t, "analytics", component.Directives.DefaultConnector) + require.NotNil(t, component.Directives.Cache) + assert.True(t, component.Directives.Cache.Enabled) + assert.Equal(t, "5m", component.Directives.Cache.TTL) + require.NotNil(t, component.Directives.MCP) + assert.Equal(t, "report.list", component.Directives.MCP.Name) + assert.True(t, 
component.ColumnsDiscovery) +} + +func TestLoader_LoadComponent_RelationFieldsPreserved(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/report"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "Rows", + Name: "rows", + Table: "REPORT", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Relations: []*plan.Relation{ + { + Name: "detail", + Holder: "Detail", + Ref: "detail", + Table: "REPORT_DETAIL", + On: []*plan.RelationLink{ + { + ParentField: "ReportID", + ParentNamespace: "rows", + ParentColumn: "REPORT_ID", + RefField: "ID", + RefNamespace: "detail", + RefColumn: "ID", + }, + }, + }, + }, + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ViewRelations, 1) + require.Len(t, component.ViewRelations[0].On, 1) + require.Len(t, component.ViewRelations[0].Of.On, 1) + + parent := component.ViewRelations[0].On[0] + ref := component.ViewRelations[0].Of.On[0] + assert.Equal(t, "ReportID", parent.Field) + assert.Equal(t, "rows", parent.Namespace) + assert.Equal(t, "REPORT_ID", parent.Column) + assert.Equal(t, "ID", ref.Field) + assert.Equal(t, "detail", ref.Namespace) + assert.Equal(t, "ID", ref.Column) +} diff --git a/repository/shape/load/model.go b/repository/shape/load/model.go new file mode 100644 index 000000000..a05f2287d --- /dev/null +++ b/repository/shape/load/model.go @@ -0,0 +1,46 @@ +package load + +import ( + "github.com/viant/datly/repository/shape" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" +) + +// Component is a shape-loaded runtime-neutral 
component artifact. +// It intentionally avoids repository package coupling to keep shape/load reusable. +type Component struct { + Name string + URI string + Method string + RootView string + Views []string + Relations []*plan.Relation + ViewRelations []*view.Relation + Declarations map[string]*plan.ViewDeclaration + QuerySelectors map[string][]string + Predicates map[string][]*plan.ViewPredicate + TypeContext *typectx.Context + Directives *dqlshape.Directives + ColumnsDiscovery bool + + Input []*plan.State + Output []*plan.State + Meta []*plan.State + Async []*plan.State + Other []*plan.State +} + +// ShapeSpecKind implements shape.ComponentSpec. +func (c *Component) ShapeSpecKind() string { return "component" } + +// ComponentFrom extracts the typed component from a ComponentArtifact. +// Returns (nil, false) when a is nil or contains an unexpected concrete type. +func ComponentFrom(a *shape.ComponentArtifact) (*Component, bool) { + if a == nil { + return nil, false + } + c, ok := a.Component.(*Component) + return c, ok && c != nil +} diff --git a/repository/shape/load/testdata/report.sql b/repository/shape/load/testdata/report.sql new file mode 100644 index 000000000..68f0f3b34 --- /dev/null +++ b/repository/shape/load/testdata/report.sql @@ -0,0 +1 @@ +SELECT ID, NAME FROM REPORT diff --git a/repository/shape/model.go b/repository/shape/model.go new file mode 100644 index 000000000..88c8da537 --- /dev/null +++ b/repository/shape/model.go @@ -0,0 +1,74 @@ +package shape + +import ( + "reflect" + + "github.com/viant/datly/view" + "github.com/viant/x" +) + +// Mode controls which execution flow is expected from the shape pipeline. +type Mode string + +const ( + ModeUnspecified Mode = "" + ModeStruct Mode = "struct" + ModeDQL Mode = "dql" +) + +// Source represents the caller-provided shape source. 
+type Source struct { + Name string + Path string + Connector string + Struct any + Type reflect.Type + TypeName string + TypeRegistry *x.Registry + DQL string +} + +// ScanSpec is implemented by every scan-pipeline descriptor result. +// The sole production implementation is *scan.Result. +type ScanSpec interface { + // ShapeSpecKind returns a diagnostic label used in error messages. + ShapeSpecKind() string +} + +// PlanSpec is implemented by every plan-pipeline result. +// The sole production implementation is *plan.Result. +type PlanSpec interface { + // ShapeSpecKind returns a diagnostic label used in error messages. + ShapeSpecKind() string +} + +// ComponentSpec is implemented by every component loader result. +// The sole production implementation is *load.Component. +type ComponentSpec interface { + // ShapeSpecKind returns a diagnostic label used in error messages. + ShapeSpecKind() string +} + +// ScanResult is the output produced by Scanner. +type ScanResult struct { + Source *Source + Descriptors ScanSpec +} + +// PlanResult is the output produced by Planner. +type PlanResult struct { + Source *Source + Plan PlanSpec +} + +// ViewArtifacts is the runtime view payload produced by Loader. +type ViewArtifacts struct { + Resource *view.Resource + Views view.Views +} + +// ComponentArtifact is the runtime component payload produced by Loader. 
+type ComponentArtifact struct { + Resource *view.Resource + Component ComponentSpec +} diff --git a/repository/shape/normalize/sql.go b/repository/shape/normalize/sql.go new file mode 100644 index 000000000..945840dd9 --- /dev/null +++ b/repository/shape/normalize/sql.go @@ -0,0 +1,56 @@ +package normalize + +import ( + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/node" + "github.com/viant/sqlparser/query" + "github.com/viant/tagly/format/text" +) + +type mapper map[string]string + +func (m mapper) Map(name string) string { + ret, ok := m[name] + if ok { + return ret + } + return name +} + +func SQL(input string, generated bool, option func() sqlparser.Option) string { + if !generated { + return input + } + sqlQuery, err := sqlparser.ParseQuery(input, option()) + if err != nil { + return input + } + ns := mapper{} + if sqlQuery.From.Alias != "" { + ns[sqlQuery.From.Alias] = normalizeName(sqlQuery.From.Alias) + } + for _, join := range sqlQuery.Joins { + ns[join.Alias] = normalizeName(join.Alias) + } + + sqlparser.Traverse(sqlQuery, func(n node.Node) bool { + switch actual := n.(type) { + case *expr.Selector: + actual.Name = ns.Map(actual.Name) + case *query.Join: + actual.Alias = ns.Map(actual.Alias) + case *query.Item: + actual.Alias = ns.Map(actual.Alias) + case *query.From: + actual.Alias = ns.Map(actual.Alias) + } + return true + }) + return sqlparser.Stringify(sqlQuery) +} + +func normalizeName(k string) string { + caseFormat := text.DetectCaseFormat(k) + return caseFormat.Format(k, text.CaseFormatUpperCamel) +} diff --git a/repository/shape/normalize/sql_test.go b/repository/shape/normalize/sql_test.go new file mode 100644 index 000000000..aaba9af7d --- /dev/null +++ b/repository/shape/normalize/sql_test.go @@ -0,0 +1,66 @@ +package normalize + +import ( + "testing" + + "github.com/stretchr/testify/require" + legacy "github.com/viant/datly/cmd/options" + "github.com/viant/sqlparser" +) + +func parserOption() 
sqlparser.Option { + return sqlparser.WithErrorHandler(nil) +} + +func TestSQL_ParityWithLegacyNormalizer(t *testing.T) { + type normalizeCase struct { + Name string + Generated bool + SQL string + } + cases := []normalizeCase{ + { + Name: "skip normalization when not generated", + Generated: false, + SQL: "SELECT a.id FROM users a JOIN orders b ON a.id = b.user_id", + }, + { + Name: "invalid sql returns input", + Generated: true, + SQL: "SELECT * FROM (", + }, + { + Name: "normalize from and join aliases in selectors and alias nodes", + Generated: true, + SQL: "SELECT a.id, b.user_id FROM users a JOIN orders b ON a.id = b.user_id", + }, + { + Name: "keep alias that is already normalized", + Generated: true, + SQL: "SELECT UserAlias.id FROM users UserAlias", + }, + { + Name: "normalize snake_case alias", + Generated: true, + SQL: "SELECT order_item.id FROM users order_item", + }, + } + for _, testCase := range cases { + t.Run(testCase.Name, func(t *testing.T) { + expected := (&legacy.Rule{Generated: testCase.Generated}).NormalizeSQL(testCase.SQL, parserOption) + actual := SQL(testCase.SQL, testCase.Generated, parserOption) + require.Equal(t, expected, actual) + }) + } +} + +func TestMapper_Map(t *testing.T) { + m := mapper{"a": "A"} + require.Equal(t, "A", m.Map("a")) + require.Equal(t, "b", m.Map("b")) +} + +func TestNormalizeName(t *testing.T) { + require.Equal(t, "UserAlias", normalizeName("user_alias")) + require.Equal(t, "UserAlias", normalizeName("UserAlias")) +} diff --git a/repository/shape/options.go b/repository/shape/options.go new file mode 100644 index 000000000..27b970fae --- /dev/null +++ b/repository/shape/options.go @@ -0,0 +1,240 @@ +package shape + +// Options stores shape facade dependencies and behavior flags. 
+type Options struct { + Mode Mode + Strict bool + Name string + Scanner Scanner + Planner Planner + Loader Loader + Compiler DQLCompiler + Runtime RuntimeRegistrar + CompileProfile CompileProfile + CompileMixedMode CompileMixedMode + UnknownNonReadMode CompileUnknownNonReadMode + ColumnDiscoveryMode CompileColumnDiscoveryMode +} + +// Option mutates Options. +type Option func(*Options) + +// NewOptions builds Options from varargs. +func NewOptions(opts ...Option) *Options { + ret := &Options{ + CompileProfile: CompileProfileCompat, + CompileMixedMode: CompileMixedModeExecWins, + UnknownNonReadMode: CompileUnknownNonReadWarn, + ColumnDiscoveryMode: CompileColumnDiscoveryAuto, + } + for _, opt := range opts { + opt(ret) + } + return ret +} + +func WithMode(mode Mode) Option { + return func(o *Options) { + o.Mode = mode + } +} + +func WithStrict(strict bool) Option { + return func(o *Options) { + o.Strict = strict + } +} + +func WithName(name string) Option { + return func(o *Options) { + o.Name = name + } +} + +func WithScanner(scanner Scanner) Option { + return func(o *Options) { + o.Scanner = scanner + } +} + +func WithPlanner(planner Planner) Option { + return func(o *Options) { + o.Planner = planner + } +} + +func WithLoader(loader Loader) Option { + return func(o *Options) { + o.Loader = loader + } +} + +func WithCompiler(compiler DQLCompiler) Option { + return func(o *Options) { + o.Compiler = compiler + } +} + +func WithRuntime(runtime RuntimeRegistrar) Option { + return func(o *Options) { + o.Runtime = runtime + } +} + +// WithCompileProfileDefault sets default compiler profile used by Engine DQL compile path. +func WithCompileProfileDefault(profile CompileProfile) Option { + return func(o *Options) { + o.CompileProfile = profile + } +} + +// WithMixedModeDefault sets default compiler mixed read/exec mode used by Engine DQL compile path. 
+func WithMixedModeDefault(mode CompileMixedMode) Option { + return func(o *Options) { + o.CompileMixedMode = mode + } +} + +// WithUnknownNonReadModeDefault sets default unknown non-read mode used by Engine DQL compile path. +func WithUnknownNonReadModeDefault(mode CompileUnknownNonReadMode) Option { + return func(o *Options) { + o.UnknownNonReadMode = mode + } +} + +// WithColumnDiscoveryModeDefault sets default column discovery policy used by Engine DQL compile path. +func WithColumnDiscoveryModeDefault(mode CompileColumnDiscoveryMode) Option { + return func(o *Options) { + o.ColumnDiscoveryMode = mode + } +} + +// WithLegacyTranslatorDefaults configures Engine compile defaults to legacy-compatible behavior. +func WithLegacyTranslatorDefaults() Option { + return func(o *Options) { + o.Strict = false + o.CompileProfile = CompileProfileCompat + o.CompileMixedMode = CompileMixedModeExecWins + o.UnknownNonReadMode = CompileUnknownNonReadWarn + o.ColumnDiscoveryMode = CompileColumnDiscoveryAuto + } +} + +func WithCompileStrict(strict bool) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.Strict = strict + } +} + +func WithMixedMode(mode CompileMixedMode) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.MixedMode = mode + } +} + +func WithUnknownNonReadMode(mode CompileUnknownNonReadMode) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.UnknownNonReadMode = mode + } +} + +func WithCompileProfile(profile CompileProfile) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.Profile = profile + } +} + +func WithColumnDiscoveryMode(mode CompileColumnDiscoveryMode) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.ColumnDiscoveryMode = mode + } +} + +// WithDQLPathMarker overrides the path marker used to locate platform root from source path. +// Default is "/dql/". 
+func WithDQLPathMarker(marker string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.DQLPathMarker = marker + } +} + +// WithRoutesRelativePath overrides routes path relative to detected platform root. +// Default is "repo/dev/Datly/routes". +func WithRoutesRelativePath(path string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.RoutesRelativePath = path + } +} + +// WithTypeContextPackageDir sets default type-context package directory (for xgen parity). +func WithTypeContextPackageDir(dir string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackageDir = dir + } +} + +// WithTypeContextPackageName sets default type-context package name (for xgen parity). +func WithTypeContextPackageName(name string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackageName = name + } +} + +// WithTypeContextPackagePath sets default type-context package import path (for xgen parity). +func WithTypeContextPackagePath(path string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackagePath = path + } +} + +// WithTypeContextPackageDefaults sets package dir/name/path in one call. +func WithTypeContextPackageDefaults(dir, name, path string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackageDir = dir + o.TypePackageName = name + o.TypePackagePath = path + } +} + +// WithInferTypeContextDefaults enables/disables source-path based type context defaults. 
+func WithInferTypeContextDefaults(enabled bool) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.InferTypeContext = &enabled + } +} diff --git a/repository/shape/parity_test.go b/repository/shape/parity_test.go new file mode 100644 index 000000000..8041328b1 --- /dev/null +++ b/repository/shape/parity_test.go @@ -0,0 +1,108 @@ +package shape_test + +import ( + "context" + "embed" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + shape "github.com/viant/datly/repository/shape" + shapeLoad "github.com/viant/datly/repository/shape/load" + shapePlan "github.com/viant/datly/repository/shape/plan" + shapeScan "github.com/viant/datly/repository/shape/scan" +) + +//go:embed scan/testdata/*.sql +var parityFS embed.FS + +type parityEmbedded struct{} + +func (parityEmbedded) EmbedFS() *embed.FS { return &parityFS } + +type parityRow struct { + ID int + Name string +} + +type paritySource struct { + parityEmbedded + Rows []parityRow `view:"rows,table=REPORT,connector=dev" sql:"uri=scan/testdata/report.sql"` +} + +type parityJoinRow struct { + ReportID int `source:"REPORT_ID"` +} + +type parityJoinSource struct { + parityEmbedded + Rows []parityJoinRow `view:"rows,table=REPORT,connector=dev" sql:"uri=scan/testdata/report.sql" on:"ReportID:rows.REPORT_ID=ID:detail.ID"` +} + +func TestEngineParity_StructPipeline(t *testing.T) { + source := &paritySource{} + scanner := shapeScan.New() + planner := shapePlan.New() + loader := shapeLoad.New() + + manualScan, err := scanner.Scan(context.Background(), &shape.Source{Name: "/v1/api/parity", Struct: source}) + require.NoError(t, err) + manualPlan, err := planner.Plan(context.Background(), manualScan) + require.NoError(t, err) + manualViews, err := loader.LoadViews(context.Background(), manualPlan) + require.NoError(t, err) + + engine := shape.New( + shape.WithName("/v1/api/parity"), + shape.WithScanner(scanner), + shape.WithPlanner(planner), + 
shape.WithLoader(loader), + ) + engineViews, err := engine.LoadViews(context.Background(), source) + require.NoError(t, err) + + require.Len(t, manualViews.Views, 1) + require.Len(t, engineViews.Views, 1) + + mv := manualViews.Views[0] + ev := engineViews.Views[0] + assert.Equal(t, mv.Name, ev.Name) + assert.Equal(t, mv.Table, ev.Table) + assert.Equal(t, mv.Template.Source, ev.Template.Source) + assert.Equal(t, mv.Template.SourceURL, ev.Template.SourceURL) + assert.Equal(t, mv.Schema.Cardinality, ev.Schema.Cardinality) + assert.Equal(t, reflect.TypeOf(mv.Schema.CompType()), reflect.TypeOf(ev.Schema.CompType())) +} + +func TestEngineParity_Component_SourceTagFieldJoin(t *testing.T) { + source := &parityJoinSource{} + scanner := shapeScan.New() + planner := shapePlan.New() + loader := shapeLoad.New() + + engine := shape.New( + shape.WithName("/v1/api/parity"), + shape.WithScanner(scanner), + shape.WithPlanner(planner), + shape.WithLoader(loader), + ) + artifact, err := engine.LoadComponent(context.Background(), source) + require.NoError(t, err) + require.NotNil(t, artifact) + + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ViewRelations, 1) + require.Len(t, component.ViewRelations[0].On, 1) + require.Len(t, component.ViewRelations[0].Of.On, 1) + + parent := component.ViewRelations[0].On[0] + ref := component.ViewRelations[0].Of.On[0] + assert.Equal(t, "ReportID", parent.Field) + assert.Equal(t, "rows", parent.Namespace) + assert.Equal(t, "REPORT_ID", parent.Column) + assert.Equal(t, "ID", ref.Field) + assert.Equal(t, "detail", ref.Namespace) + assert.Equal(t, "ID", ref.Column) +} diff --git a/repository/shape/plan/doc.go b/repository/shape/plan/doc.go new file mode 100644 index 000000000..57bb65fae --- /dev/null +++ b/repository/shape/plan/doc.go @@ -0,0 +1,2 @@ +// Package plan defines normalization and shape-planning responsibilities. 
+package plan diff --git a/repository/shape/plan/model.go b/repository/shape/plan/model.go new file mode 100644 index 000000000..ffb295f28 --- /dev/null +++ b/repository/shape/plan/model.go @@ -0,0 +1,152 @@ +package plan + +import ( + "embed" + "reflect" + "strings" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view/state" +) + +// Result is normalized shape plan produced from scan descriptors. +type Result struct { + RootType reflect.Type + EmbedFS *embed.FS + + Fields []*Field + ByPath map[string]*Field + Views []*View + ViewsByName map[string]*View + States []*State + Types []*Type + ColumnsDiscovery bool + TypeContext *typectx.Context + Directives *dqlshape.Directives + Diagnostics []*dqlshape.Diagnostic +} + +// Type is normalized type metadata collected during compile. +type Type struct { + Name string + Alias string + DataType string + Cardinality string + Package string + ModulePath string +} + +// Field is a normalized projection of scanned field metadata. +type Field struct { + Path string + Name string + Type reflect.Type + Index []int +} + +// View is a normalized view field plan. +type View struct { + Path string + Name string + Ref string + Mode string + Table string + Module string + Connector string + CacheRef string + Partitioner string + PartitionedConcurrency int + RelationalConcurrency int + SQL string + SQLURI string + Summary string + Relations []*Relation + Holder string + + AllowNulls *bool + SelectorNamespace string + SelectorNoLimit *bool + SchemaType string + ColumnsDiscovery bool + + Cardinality string + ElementType reflect.Type + FieldType reflect.Type + Declaration *ViewDeclaration +} + +// ViewDeclaration captures declaration options used to derive a view from DQL directives. 
type ViewDeclaration struct {
	// Raw declaration tag text as it appeared in the DQL source.
	Tag string
	// Codec name plus its arguments applied to the view output.
	Codec     string
	CodecArgs []string
	// Handler name plus its arguments; empty when no handler is declared.
	HandlerName string
	HandlerArgs []string
	// StatusCode/ErrorMessage configure the declared error response; nil/empty means unset.
	StatusCode    *int
	ErrorMessage  string
	QuerySelector string
	CacheRef      string
	// Limit is an optional row limit; nil means not declared.
	Limit *int
	// Cacheable is tri-state: nil means not declared.
	Cacheable *bool
	When      string
	Scope     string
	DataType  string
	// Of/Value carry declaration-specific operands; exact semantics depend on the
	// directive that produced them — NOTE(review): confirm against the DQL compiler.
	Of    string
	Value string
	// Async/Output mark the declaration as async-bound or output-bound state.
	Async  bool
	Output bool
	// Predicates declared via WithPredicate / EnsurePredicate.
	Predicates []*ViewPredicate
}

// ViewPredicate captures WithPredicate / EnsurePredicate metadata.
type ViewPredicate struct {
	Name   string
	Source string
	// Ensure is true for EnsurePredicate declarations, false for WithPredicate.
	Ensure    bool
	Arguments []string
}

// Relation is normalized relation metadata extracted from DQL joins.
type Relation struct {
	Name   string
	Holder string
	Ref    string
	Table  string
	Kind   string
	// Raw preserves the original join expression text.
	Raw string
	// On holds the parent/ref join predicates linking the two views.
	On []*RelationLink
	// Warnings collects non-fatal notes gathered while extracting the relation.
	Warnings []string
}

// RelationLink represents one parent/ref join predicate.
type RelationLink struct {
	ParentField     string
	ParentNamespace string
	ParentColumn    string
	RefField        string
	RefNamespace    string
	RefColumn       string
	// Expression is the normalized textual form of the predicate (parent=ref).
	Expression string
}

// State is a normalized parameter field plan.
+type State struct { + state.Parameter `yaml:",inline"` + QuerySelector string + OutputDataType string +} + +func (s *State) KindString() string { + if s == nil || s.In == nil { + return "" + } + return strings.TrimSpace(string(s.In.Kind)) +} + +func (s *State) InName() string { + if s == nil || s.In == nil { + return "" + } + return strings.TrimSpace(s.In.Name) +} diff --git a/repository/shape/plan/planner.go b/repository/shape/plan/planner.go new file mode 100644 index 000000000..f78b4bcc1 --- /dev/null +++ b/repository/shape/plan/planner.go @@ -0,0 +1,250 @@ +package plan + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/repository/locator/async/keys" + metakeys "github.com/viant/datly/repository/locator/meta/keys" + outputkeys "github.com/viant/datly/repository/locator/output/keys" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/scan" + "github.com/viant/datly/view/state" +) + +// Planner normalizes scan descriptors into shape plan. +type Planner struct{} + +// New returns shape planner implementation. +func New() *Planner { + return &Planner{} +} + +// Plan implements shape.Planner. 
+func (p *Planner) Plan(ctx context.Context, scanned *shape.ScanResult, _ ...shape.PlanOption) (*shape.PlanResult, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if scanned == nil || scanned.Source == nil { + return nil, shape.ErrNilSource + } + + scanResult, ok := scan.DescriptorsFrom(scanned) + if !ok { + return nil, fmt.Errorf("shape plan: unsupported descriptors kind %q", scanned.Descriptors.ShapeSpecKind()) + } + + result := &Result{ + RootType: scanResult.RootType, + EmbedFS: scanResult.EmbedFS, + ByPath: map[string]*Field{}, + ViewsByName: map[string]*View{}, + } + + for _, item := range scanResult.Fields { + field := &Field{ + Path: item.Path, + Name: item.Name, + Type: item.Type, + Index: append([]int(nil), item.Index...), + } + result.Fields = append(result.Fields, field) + result.ByPath[field.Path] = field + } + + for _, item := range scanResult.ViewFields { + v := normalizeView(item) + result.Views = append(result.Views, v) + if v.Name != "" { + result.ViewsByName[v.Name] = v + } + } + + for _, item := range scanResult.StateFields { + result.States = append(result.States, normalizeState(item)) + } + + return &shape.PlanResult{Source: scanned.Source, Plan: result}, nil +} + +func normalizeView(field *scan.Field) *View { + result := &View{ + Path: field.Path, + Holder: field.Name, + FieldType: field.Type, + } + + if tag := field.ViewTag; tag != nil { + if tag.View != nil { + result.Name = tag.View.Name + result.Table = tag.View.Table + result.Connector = tag.View.Connector + result.CacheRef = tag.View.Cache + result.Partitioner = tag.View.PartitionerType + result.PartitionedConcurrency = tag.View.PartitionedConcurrency + result.RelationalConcurrency = tag.View.RelationalConcurrency + } + result.SQL = tag.SQL.SQL + result.SQLURI = tag.SQL.URI + result.Summary = tag.SummarySQL.SQL + if len(tag.LinkOn) > 0 { + result.Relations = append(result.Relations, relationFromTagLinks(field.Name, tag.LinkOn)) + } + result.Ref = 
// splitTagSelector parses one side of an "on" link expression of the form
// "Field:namespace.column" (field and namespace optional, value possibly
// wrapped in back-ticks or quotes and suffixed with "(true)"/"(false)").
// It returns (field, namespace, column); missing parts are "".
func splitTagSelector(value string) (string, string, string) {
	value = strings.TrimSpace(value)
	// Drop an optional boolean suffix such as `rows.id`(true).
	value = strings.TrimSuffix(value, "(true)")
	value = strings.TrimSuffix(value, "(false)")

	field := ""
	// An optional "Field:" prefix names the Go struct field for the link.
	if prefix, rest, ok := strings.Cut(value, ":"); ok {
		field = strings.TrimSpace(prefix)
		value = rest
	}
	// Strip back-tick/quote wrapping around the selector.
	value = strings.Trim(value, "`\"")
	if value == "" {
		return field, "", ""
	}
	// "namespace.column" when dotted, bare column otherwise.
	if namespace, column, ok := strings.Cut(value, "."); ok {
		return field, strings.TrimSpace(namespace), strings.TrimSpace(column)
	}
	return field, "", strings.TrimSpace(value)
}
In: &state.Location{}, + }, + } + if field.StateTag == nil || field.StateTag.Parameter == nil { + result.Schema = state.NewSchema(field.Type) + return result + } + + pTag := field.StateTag.Parameter + result.Name = firstNonEmpty(pTag.Name, field.Name) + result.In = &state.Location{ + Kind: state.Kind(strings.ToLower(strings.TrimSpace(pTag.Kind))), + Name: strings.TrimSpace(pTag.In), + } + result.When = pTag.When + result.Scope = pTag.Scope + result.Required = pTag.Required + result.Async = pTag.Async + result.Cacheable = pTag.Cacheable + result.With = pTag.With + result.URI = pTag.URI + result.ErrorStatusCode = pTag.ErrorCode + result.ErrorMessage = pTag.ErrorMessage + + result.Schema = state.NewSchema(resolveStateType(result, field.Type)) + if dataType := strings.TrimSpace(pTag.DataType); dataType != "" { + result.Schema.DataType = dataType + } + return result +} + +func resolveStateType(item *State, fallback reflect.Type) reflect.Type { + if item.In == nil { + return fallback + } + key := strings.ToLower(strings.TrimSpace(firstNonEmpty(item.In.Name, item.Name))) + switch item.In.Kind { + case state.KindOutput: + if rType, ok := outputkeys.Types[key]; ok { + return rType + } + case state.KindMeta: + if rType, ok := metakeys.Types[key]; ok { + return rType + } + case state.KindAsync: + if rType, ok := keys.Types[key]; ok { + return rType + } + } + return fallback +} + +func componentType(rType reflect.Type) (reflect.Type, string) { + if rType == nil { + return nil, "one" + } + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType.Kind() == reflect.Slice { + elem := rType.Elem() + for elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + return elem, "many" + } + return rType, "one" +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/repository/shape/plan/planner_test.go 
b/repository/shape/plan/planner_test.go new file mode 100644 index 000000000..dc0416eb8 --- /dev/null +++ b/repository/shape/plan/planner_test.go @@ -0,0 +1,152 @@ +package plan + +import ( + "context" + "embed" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + asynckeys "github.com/viant/datly/repository/locator/async/keys" + metakeys "github.com/viant/datly/repository/locator/meta/keys" + outputkeys "github.com/viant/datly/repository/locator/output/keys" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/scan" +) + +//go:embed testdata/*.sql +var testFS embed.FS + +type embeddedFS struct{} + +func (embeddedFS) EmbedFS() *embed.FS { + return &testFS +} + +type reportRow struct { + ID int +} + +type reportSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT,connector=dev" sql:"uri=testdata/report.sql"` + Status interface{} `parameter:"status,kind=output,in=status"` + Job interface{} `parameter:"job,kind=async,in=job"` + VName interface{} `parameter:"viewName,kind=meta,in=view.name"` + ID int `parameter:"id,kind=query,in=id"` +} + +type relationRow struct { + ID int +} + +type relationSource struct { + Rows []relationRow `view:"rows,table=REPORT" on:"rows.report_id=report.id"` +} + +type relationSourceWithFields struct { + Rows []relationRow `view:"rows,table=REPORT" on:"ReportID:rows.report_id=ID:report.id"` +} + +func TestPlanner_Plan(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + require.NotNil(t, planned) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.NotNil(t, result) + require.NotNil(t, result.EmbedFS) + + require.Len(t, result.Views, 1) + rows := result.Views[0] + assert.Equal(t, "rows", rows.Name) + 
assert.Equal(t, "REPORT", rows.Table) + assert.Equal(t, "dev", rows.Connector) + assert.Equal(t, "many", rows.Cardinality) + assert.Equal(t, "Rows", rows.Holder) + assert.Contains(t, rows.SQL, "SELECT ID") + + stateByPath := map[string]*State{} + for _, item := range result.States { + stateByPath[strings.ToLower(item.Name)] = item + } + + require.NotNil(t, stateByPath["status"]) + assert.Equal(t, outputkeys.Types["status"], stateByPath["status"].Schema.Type()) + require.NotNil(t, stateByPath["job"]) + assert.Equal(t, asynckeys.Types["job"], stateByPath["job"].Schema.Type()) + require.NotNil(t, stateByPath["viewname"]) + assert.Equal(t, metakeys.Types["view.name"], stateByPath["viewname"].Schema.Type()) + + require.NotNil(t, stateByPath["id"]) + assert.Equal(t, "query", stateByPath["id"].KindString()) + assert.Equal(t, "id", stateByPath["id"].InName()) +} + +func TestPlanner_Plan_LinkOnProducesStructuredRelations(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &relationSource{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + require.NotNil(t, planned) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Views, 1) + viewPlan := result.Views[0] + require.Len(t, viewPlan.Relations, 1) + relation := viewPlan.Relations[0] + require.Len(t, relation.On, 1) + assert.Equal(t, "rows", relation.On[0].ParentNamespace) + assert.Equal(t, "report_id", relation.On[0].ParentColumn) + assert.Equal(t, "report", relation.On[0].RefNamespace) + assert.Equal(t, "id", relation.On[0].RefColumn) +} + +func TestPlanner_Plan_LinkOnPreservesFieldSelectors(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &relationSourceWithFields{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + 
require.NoError(t, err) + require.NotNil(t, planned) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Views, 1) + viewPlan := result.Views[0] + require.Len(t, viewPlan.Relations, 1) + relation := viewPlan.Relations[0] + require.Len(t, relation.On, 1) + assert.Equal(t, "ReportID", relation.On[0].ParentField) + assert.Equal(t, "rows", relation.On[0].ParentNamespace) + assert.Equal(t, "report_id", relation.On[0].ParentColumn) + assert.Equal(t, "ID", relation.On[0].RefField) + assert.Equal(t, "report", relation.On[0].RefNamespace) + assert.Equal(t, "id", relation.On[0].RefColumn) +} + +// stubScanSpec is a non-scan-Result implementation of shape.ScanSpec used to +// verify that Plan() returns an error when given an unexpected descriptor type. +type stubScanSpec struct{} + +func (s *stubScanSpec) ShapeSpecKind() string { return "stub" } + +func TestPlanner_Plan_InvalidDescriptors(t *testing.T) { + planner := New() + _, err := planner.Plan(context.Background(), &shape.ScanResult{Source: &shape.Source{Name: "x"}, Descriptors: &stubScanSpec{}}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported descriptors kind") +} diff --git a/repository/shape/plan/spec.go b/repository/shape/plan/spec.go new file mode 100644 index 000000000..9003699d1 --- /dev/null +++ b/repository/shape/plan/spec.go @@ -0,0 +1,16 @@ +package plan + +import "github.com/viant/datly/repository/shape" + +// ShapeSpecKind implements shape.PlanSpec. +func (r *Result) ShapeSpecKind() string { return "plan" } + +// ResultFrom extracts the typed plan result from a PlanResult. +// Returns (nil, false) when a is nil or contains an unexpected concrete type. 
+func ResultFrom(a *shape.PlanResult) (*Result, bool) { + if a == nil { + return nil, false + } + r, ok := a.Plan.(*Result) + return r, ok && r != nil +} diff --git a/repository/shape/plan/testdata/report.sql b/repository/shape/plan/testdata/report.sql new file mode 100644 index 000000000..7aab3a1f8 --- /dev/null +++ b/repository/shape/plan/testdata/report.sql @@ -0,0 +1 @@ +SELECT ID FROM REPORT diff --git a/repository/shape/platform_parity_metadata_test.go b/repository/shape/platform_parity_metadata_test.go new file mode 100644 index 000000000..8ad4a6d8c --- /dev/null +++ b/repository/shape/platform_parity_metadata_test.go @@ -0,0 +1,77 @@ +package shape_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCompareMetadataParity(t *testing.T) { + trueValue := true + falseValue := false + + legacyMeta := &resourceMetaIR{ColumnsDiscovery: &trueValue} + shapeMeta := &resourceMetaIR{ColumnsDiscovery: &trueValue} + + legacyViews := []viewMetaIR{ + { + Name: "items", + Mode: "SQLQuery", + Module: "platform/items", + AllowNulls: &trueValue, + SelectorNamespace: "item", + SelectorNoLimit: &falseValue, + SchemaCardinality: "Many", + SchemaType: "*ItemView", + HasSummary: &trueValue, + }, + } + shapeViews := []viewMetaIR{ + { + Name: "items", + Mode: "SQLQuery", + Module: "platform/items", + AllowNulls: &trueValue, + SelectorNamespace: "item", + SelectorNoLimit: &falseValue, + SchemaCardinality: "Many", + SchemaType: "*ItemView", + HasSummary: &trueValue, + }, + } + + assert.Empty(t, compareMetadataParity(legacyMeta, shapeMeta, legacyViews, shapeViews)) +} + +func TestCompareMetadataParity_DetectsMismatches(t *testing.T) { + trueValue := true + falseValue := false + + legacyMeta := &resourceMetaIR{ColumnsDiscovery: &trueValue} + shapeMeta := &resourceMetaIR{ColumnsDiscovery: &falseValue} + + legacyViews := []viewMetaIR{{ + Name: "items", + Mode: "SQLQuery", + Module: "platform/items", + AllowNulls: &trueValue, + SelectorNoLimit: &trueValue, + 
SchemaType: "*ItemView", + }} + shapeViews := []viewMetaIR{{ + Name: "items", + Mode: "SQLExec", + Module: "platform/items2", + AllowNulls: &falseValue, + SelectorNoLimit: &falseValue, + SchemaType: "*OtherView", + }} + + mismatches := compareMetadataParity(legacyMeta, shapeMeta, legacyViews, shapeViews) + assert.Contains(t, mismatches, "resource columnsDiscovery mismatch") + assert.Contains(t, mismatches, "view mode mismatch: items") + assert.Contains(t, mismatches, "view module mismatch: items") + assert.Contains(t, mismatches, "view allowNulls mismatch: items") + assert.Contains(t, mismatches, "view selector noLimit mismatch: items") + assert.Contains(t, mismatches, "view schema type mismatch: items") +} diff --git a/repository/shape/platform_parity_test.go b/repository/shape/platform_parity_test.go new file mode 100644 index 000000000..122a6b666 --- /dev/null +++ b/repository/shape/platform_parity_test.go @@ -0,0 +1,1482 @@ +package shape_test + +import ( + "context" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + "testing" + + shape "github.com/viant/datly/repository/shape" + shapecompile "github.com/viant/datly/repository/shape/compile" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + shapeload "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view" + "gopkg.in/yaml.v3" +) + +type parityRule struct { + Mode string `yaml:"mode"` + Namespace string `yaml:"namespace"` + Source string `yaml:"source"` + Connector string `yaml:"connector,omitempty"` +} + +type legacyYAML struct { + ColumnsDiscovery *bool `yaml:"ColumnsDiscovery"` + TypeContext struct { + DefaultPackage string `yaml:"DefaultPackage"` + PackageDir string `yaml:"PackageDir"` + PackageName string `yaml:"PackageName"` + PackagePath string `yaml:"PackagePath"` + } `yaml:"TypeContext"` + Resource struct { + Views []struct { 
+ Name string `yaml:"Name"` + Table string `yaml:"Table"` + Mode string `yaml:"Mode"` + Module string `yaml:"Module"` + AllowNulls *bool `yaml:"AllowNulls"` + Connector struct { + Ref string `yaml:"Ref"` + } `yaml:"Connector"` + Schema struct { + Cardinality string `yaml:"Cardinality"` + DataType string `yaml:"DataType"` + Name string `yaml:"Name"` + } `yaml:"Schema"` + Template struct { + SourceURL string `yaml:"SourceURL"` + Summary *struct { + Name string `yaml:"Name"` + Kind string `yaml:"Kind"` + } `yaml:"Summary"` + } `yaml:"Template"` + Selector struct { + Namespace string `yaml:"Namespace"` + NoLimit *bool `yaml:"NoLimit"` + LimitParameter selectorParam `yaml:"LimitParameter"` + OffsetParameter selectorParam `yaml:"OffsetParameter"` + PageParameter selectorParam `yaml:"PageParameter"` + FieldsParameter selectorParam `yaml:"FieldsParameter"` + OrderByParameter selectorParam `yaml:"OrderByParameter"` + } `yaml:"Selector"` + } `yaml:"Views"` + Parameters []struct { + Name string `yaml:"Name"` + URI string `yaml:"URI"` + Value string `yaml:"Value"` + Required *bool `yaml:"Required"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + Predicates []struct { + Group int `yaml:"Group"` + Name string `yaml:"Name"` + Ensure bool `yaml:"Ensure"` + Args []string `yaml:"Args"` + } `yaml:"Predicates"` + } `yaml:"Parameters"` + Types []struct { + Name string `yaml:"Name"` + Alias string `yaml:"Alias"` + DataType string `yaml:"DataType"` + Cardinality string `yaml:"Cardinality"` + Package string `yaml:"Package"` + ModulePath string `yaml:"ModulePath"` + } `yaml:"Types"` + } `yaml:"Resource"` + Routes []struct { + Method string `yaml:"Method"` + URI string `yaml:"URI"` + View struct { + Ref string `yaml:"Ref"` + } `yaml:"View"` + } `yaml:"Routes"` +} + +type viewIR struct { + Name string `yaml:"name"` + Table string `yaml:"table"` + Connector string `yaml:"connector,omitempty"` + SQLURI string 
`yaml:"sqlUri,omitempty"` +} + +type routeIR struct { + Method string `yaml:"method,omitempty"` + URI string `yaml:"uri,omitempty"` + View string `yaml:"view,omitempty"` +} + +type resourceMetaIR struct { + ColumnsDiscovery *bool `yaml:"columnsDiscovery,omitempty"` +} + +type viewMetaIR struct { + Name string `yaml:"name"` + Mode string `yaml:"mode,omitempty"` + Module string `yaml:"module,omitempty"` + AllowNulls *bool `yaml:"allowNulls,omitempty"` + SelectorNamespace string `yaml:"selectorNamespace,omitempty"` + SelectorNoLimit *bool `yaml:"selectorNoLimit,omitempty"` + SchemaCardinality string `yaml:"schemaCardinality,omitempty"` + SchemaType string `yaml:"schemaType,omitempty"` + HasSummary *bool `yaml:"hasSummary,omitempty"` +} + +type parityOutput struct { + Namespace string `yaml:"namespace"` + Source string `yaml:"source"` + LegacyYAML string `yaml:"legacyYaml"` + LegacyMeta *resourceMetaIR `yaml:"legacyMeta,omitempty"` + LegacyViews []viewIR `yaml:"legacyViews,omitempty"` + LegacyViewMeta []viewMetaIR `yaml:"legacyViewMeta,omitempty"` + LegacyParams []paramIR `yaml:"legacyParams,omitempty"` + LegacyRoutes []routeIR `yaml:"legacyRoutes,omitempty"` + LegacyTypes []typeIR `yaml:"legacyTypes,omitempty"` + LegacyTypeCtx *typeCtxIR `yaml:"legacyTypeContext,omitempty"` + ShapeMeta *resourceMetaIR `yaml:"shapeMeta,omitempty"` + ShapeViews []viewIR `yaml:"shapeViews,omitempty"` + ShapeViewMeta []viewMetaIR `yaml:"shapeViewMeta,omitempty"` + ShapeParams []paramIR `yaml:"shapeParams,omitempty"` + ShapeTypes []typeIR `yaml:"shapeTypes,omitempty"` + ShapeTypeCtx *typeCtxIR `yaml:"shapeTypeContext,omitempty"` + ShapeDiags []string `yaml:"shapeDiagnostics,omitempty"` + Mismatches []string `yaml:"mismatches,omitempty"` + CompileFailed bool `yaml:"compileFailed,omitempty"` + RawDiagnostics []*dqlshape.Diagnostic `yaml:"-"` +} + +type parityReport struct { + Total int `yaml:"total"` + Compared int `yaml:"compared"` + WithDiff int `yaml:"withDiff"` + MissingYAML int 
`yaml:"missingYaml"` + Failures int `yaml:"failures"` + TopIssues []string `yaml:"topIssues,omitempty"` +} + +type selectorParam struct { + Name string `yaml:"Name"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` +} + +type paramIR struct { + Name string `yaml:"name"` + Kind string `yaml:"kind,omitempty"` + In string `yaml:"in,omitempty"` + Required *bool `yaml:"required,omitempty"` + Cacheable *bool `yaml:"cacheable,omitempty"` + URI string `yaml:"uri,omitempty"` + Value string `yaml:"value,omitempty"` + QuerySelector string `yaml:"querySelector,omitempty"` + Predicates []string `yaml:"predicates,omitempty"` +} + +type typeIR struct { + Name string `yaml:"name"` + Alias string `yaml:"alias,omitempty"` + DataType string `yaml:"dataType,omitempty"` + Cardinality string `yaml:"cardinality,omitempty"` + Package string `yaml:"package,omitempty"` + ModulePath string `yaml:"modulePath,omitempty"` +} + +type typeCtxIR struct { + DefaultPackage string `yaml:"defaultPackage,omitempty"` + PackageDir string `yaml:"packageDir,omitempty"` + PackageName string `yaml:"packageName,omitempty"` + PackagePath string `yaml:"packagePath,omitempty"` +} + +type parityEntryEval struct { + Output parityOutput + SourceReadable bool + MissingLegacyYAML bool +} + +func TestPlatform_DQLToRoute_ParityIR_SmokeHandlers(t *testing.T) { + if !strings.EqualFold(strings.TrimSpace(os.Getenv("PLATFORM_PARITY_SMOKE")), "1") { + t.Skip("set PLATFORM_PARITY_SMOKE=1 to run legacy parity smoke handlers") + } + platformRoot := os.Getenv("PLATFORM_ROOT") + if platformRoot == "" { + platformRoot = "/Users/awitas/go/src/github.vianttech.com/viant/platform" + } + rulesRoot := filepath.Join(platformRoot, "e2e", "rule") + routesRoot := filepath.Join(platformRoot, "repo", "dev", "Datly", "routes") + if _, err := os.Stat(rulesRoot); err != nil { + if os.Getenv("PLATFORM_PARITY_SMOKE_REQUIRED") == "1" { + t.Fatalf("platform rules not found at 
%s", rulesRoot) + } + t.Skipf("platform rules not found at %s", rulesRoot) + } + entries, err := collectRuleMappings(rulesRoot) + if err != nil { + t.Fatalf("collect mappings: %v", err) + } + if len(entries) == 0 { + t.Fatalf("no dql->route mappings found under %s", rulesRoot) + } + entryBySource := map[string]parityRule{} + for _, entry := range entries { + entryBySource[entry.Source] = entry + } + highRiskHandlers := collectSmokeHandlerSources(entries, routesRoot) + if len(highRiskHandlers) < 5 { + t.Fatalf("smoke handler discovery returned too few sources: %d", len(highRiskHandlers)) + } + + compiler := shapecompile.New() + for _, source := range highRiskHandlers { + entry, ok := entryBySource[source] + if !ok { + t.Fatalf("smoke source not found in rule mappings: %s", source) + } + eval := evaluateParityEntry(platformRoot, routesRoot, entry, compiler) + if !eval.SourceReadable { + t.Fatalf("unable to read source for smoke source: %s", source) + } + if eval.MissingLegacyYAML { + t.Fatalf("missing legacy yaml for smoke source: %s", source) + } + out := eval.Output + if out.CompileFailed { + t.Fatalf("shape compile failed for %s: %v", source, out.ShapeDiags) + } + if len(out.Mismatches) > 0 { + t.Fatalf("parity mismatches for %s: %v", source, out.Mismatches) + } + } +} + +func TestPlatform_DQLToRoute_ParityIR(t *testing.T) { + platformRoot := os.Getenv("PLATFORM_ROOT") + if platformRoot == "" { + platformRoot = "/Users/awitas/go/src/github.vianttech.com/viant/platform" + } + rulesRoot := filepath.Join(platformRoot, "e2e", "rule") + routesRoot := filepath.Join(platformRoot, "repo", "dev", "Datly", "routes") + if _, err := os.Stat(rulesRoot); err != nil { + t.Skipf("platform rules not found at %s", rulesRoot) + } + entries, err := collectRuleMappings(rulesRoot) + if err != nil { + t.Fatalf("collect mappings: %v", err) + } + if len(entries) == 0 { + t.Fatalf("no dql->route mappings found under %s", rulesRoot) + } + targetSource := 
strings.TrimSpace(os.Getenv("PLATFORM_PARITY_SOURCE")) + runAll := strings.EqualFold(targetSource, "all") || targetSource == "*" || strings.EqualFold(strings.TrimSpace(os.Getenv("PLATFORM_PARITY_ALL")), "1") + if targetSource == "" && !runAll { + t.Skip("set PLATFORM_PARITY_SOURCE to run transient platform parity check") + } + if !runAll { + var filtered []parityRule + for _, entry := range entries { + if entry.Source == targetSource { + filtered = append(filtered, entry) + } + } + if len(filtered) == 0 { + t.Fatalf("target source not found in rules: %s", targetSource) + } + entries = filtered + } + + compiler := shapecompile.New() + report := parityReport{Total: len(entries)} + issueCounts := map[string]int{} + + for _, entry := range entries { + eval := evaluateParityEntry(platformRoot, routesRoot, entry, compiler) + if !eval.SourceReadable { + continue + } + if eval.MissingLegacyYAML { + report.MissingYAML++ + continue + } + report.Compared++ + out := eval.Output + routeYAMLPath := out.LegacyYAML + if out.CompileFailed { + issueCounts["shape compile failed"]++ + report.Failures++ + writeIRFile(routeYAMLPath+".shape.ir.yaml", out) + report.WithDiff++ + continue + } + if len(out.Mismatches) > 0 { + report.WithDiff++ + for _, m := range out.Mismatches { + issueCounts[m]++ + } + } + writeIRFile(routeYAMLPath+".shape.ir.yaml", out) + } + + report.TopIssues = topIssues(issueCounts, 10) + reportPath := filepath.Join(routesRoot, "_shape_parity_report.yaml") + writeYAML(reportPath, report) + t.Logf("parity report: %s", reportPath) + t.Logf("total=%d compared=%d withDiff=%d missingYaml=%d failures=%d", report.Total, report.Compared, report.WithDiff, report.MissingYAML, report.Failures) +} + +func collectSmokeHandlerSources(entries []parityRule, routesRoot string) []string { + excluded := map[string]bool{} + var result []string + for _, entry := range entries { + source := strings.TrimSpace(entry.Source) + if !isHandlerLikeSource(source) { + continue + } + if 
excluded[source] { + continue + } + routeYAMLPath := filepath.Join(routesRoot, entry.Namespace, routeYAMLName(source)) + if _, err := os.Stat(routeYAMLPath); err != nil { + continue + } + result = append(result, source) + } + sort.Strings(result) + return dedupe(result) +} + +func isHandlerLikeSource(source string) bool { + source = strings.ToLower(strings.TrimSpace(source)) + if source == "" { + return false + } + if strings.Contains(source, "/gen/") && (strings.HasSuffix(source, ".dql") || strings.HasSuffix(source, ".sql")) { + return true + } + return strings.HasSuffix(source, "/patch.dql") || + strings.HasSuffix(source, "/patch.sql") || + strings.HasSuffix(source, "/post.dql") || + strings.HasSuffix(source, "/post.sql") || + strings.HasSuffix(source, "/put.dql") || + strings.HasSuffix(source, "/put.sql") || + strings.HasSuffix(source, "/delete.dql") || + strings.HasSuffix(source, "/delete.sql") || + strings.HasSuffix(source, "/upload.dql") || + strings.HasSuffix(source, "/upload.sql") || + strings.HasSuffix(source, "/export.dql") || + strings.HasSuffix(source, "/export.sql") || + strings.HasSuffix(source, "/action.dql") || + strings.HasSuffix(source, "/action.sql") +} + +func evaluateParityEntry(platformRoot, routesRoot string, entry parityRule, compiler *shapecompile.DQLCompiler) parityEntryEval { + sourcePath := filepath.Join(platformRoot, entry.Source) + routeYAMLPath, _ := resolveLegacyRouteYAMLPath(routesRoot, entry.Namespace, entry.Source) + if routeYAMLPath == "" { + routeYAMLPath = filepath.Join(routesRoot, entry.Namespace, routeYAMLName(entry.Source)) + } + out := parityEntryEval{Output: parityOutput{ + Namespace: entry.Namespace, + Source: entry.Source, + LegacyYAML: routeYAMLPath, + }} + sourceBytes, readErr := os.ReadFile(sourcePath) + if readErr != nil { + return out + } + out.SourceReadable = true + legacyBytes, legacyErr := os.ReadFile(routeYAMLPath) + if legacyErr != nil { + out.MissingLegacyYAML = true + return out + } + sourceName := 
strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + if sourceName == "" { + sourceName = entry.Namespace + } + + var legacy legacyYAML + if err := yaml.Unmarshal(legacyBytes, &legacy); err == nil { + out.Output.LegacyMeta = &resourceMetaIR{ColumnsDiscovery: legacy.ColumnsDiscovery} + out.Output.LegacyViews = make([]viewIR, 0, len(legacy.Resource.Views)) + out.Output.LegacyViewMeta = make([]viewMetaIR, 0, len(legacy.Resource.Views)) + for _, v := range legacy.Resource.Views { + out.Output.LegacyViews = append(out.Output.LegacyViews, viewIR{ + Name: v.Name, + Table: v.Table, + Connector: v.Connector.Ref, + SQLURI: v.Template.SourceURL, + }) + var hasSummary *bool + if v.Template.Summary != nil { + value := true + hasSummary = &value + } + out.Output.LegacyViewMeta = append(out.Output.LegacyViewMeta, viewMetaIR{ + Name: strings.TrimSpace(v.Name), + Mode: strings.TrimSpace(v.Mode), + Module: strings.TrimSpace(v.Module), + AllowNulls: v.AllowNulls, + SelectorNamespace: strings.TrimSpace(v.Selector.Namespace), + SelectorNoLimit: v.Selector.NoLimit, + SchemaCardinality: strings.TrimSpace(v.Schema.Cardinality), + SchemaType: firstNonEmpty(strings.TrimSpace(v.Schema.DataType), strings.TrimSpace(v.Schema.Name)), + HasSummary: hasSummary, + }) + } + for _, r := range legacy.Routes { + out.Output.LegacyRoutes = append(out.Output.LegacyRoutes, routeIR{ + Method: r.Method, + URI: r.URI, + View: r.View.Ref, + }) + } + out.Output.LegacyTypeCtx = normalizeTypeContextIR( + legacy.TypeContext.DefaultPackage, + legacy.TypeContext.PackageDir, + legacy.TypeContext.PackageName, + legacy.TypeContext.PackagePath, + ) + out.Output.LegacyParams = normalizeLegacyParams(legacy) + out.Output.LegacyTypes = normalizeLegacyTypes(legacy) + } + + planResult, compileErr := compiler.Compile(context.Background(), &shape.Source{ + Name: sourceName, + Path: sourcePath, + Connector: entry.Connector, + DQL: string(sourceBytes), + }) + if compileErr != nil { + out.Output.CompileFailed 
= true + if cErr, ok := compileErr.(*shapecompile.CompileError); ok { + out.Output.RawDiagnostics = cErr.Diagnostics + for _, d := range cErr.Diagnostics { + if d == nil { + continue + } + out.Output.ShapeDiags = append(out.Output.ShapeDiags, d.Error()) + } + } else { + out.Output.ShapeDiags = append(out.Output.ShapeDiags, compileErr.Error()) + } + out.Output.Mismatches = append(out.Output.Mismatches, "shape compile failed") + return out + } + + planned, _ := plan.ResultFrom(planResult) + if planned != nil { + out.Output.ShapeMeta = &resourceMetaIR{} + if sourcePath != "" { + value := true + out.Output.ShapeMeta.ColumnsDiscovery = &value + } + out.Output.ShapeViews = make([]viewIR, 0, len(planned.Views)) + out.Output.ShapeViewMeta = make([]viewMetaIR, 0, len(planned.Views)) + for _, v := range planned.Views { + if v == nil { + continue + } + out.Output.ShapeViews = append(out.Output.ShapeViews, viewIR{ + Name: v.Name, + Table: v.Table, + Connector: v.Connector, + SQLURI: v.SQLURI, + }) + var hasSummary *bool + if strings.TrimSpace(v.Summary) != "" { + value := true + hasSummary = &value + } + out.Output.ShapeViewMeta = append(out.Output.ShapeViewMeta, viewMetaIR{ + Name: strings.TrimSpace(v.Name), + Mode: inferShapeViewMode(v.SQL), + Module: strings.TrimSpace(v.Module), + AllowNulls: v.AllowNulls, + SelectorNamespace: strings.TrimSpace(v.SelectorNamespace), + SelectorNoLimit: v.SelectorNoLimit, + SchemaCardinality: normalizeCardinality(strings.TrimSpace(v.Cardinality)), + SchemaType: strings.TrimSpace(v.SchemaType), + HasSummary: hasSummary, + }) + } + for _, d := range planned.Diagnostics { + if d == nil { + continue + } + out.Output.ShapeDiags = append(out.Output.ShapeDiags, d.Error()) + } + loader := shapeload.New() + if artifacts, err := loader.LoadViews(context.Background(), planResult); err == nil && artifacts != nil && artifacts.Resource != nil { + mergeShapeViewMetadata(out.Output.ShapeViewMeta, artifacts.Resource.Views) + } + out.Output.ShapeParams = 
normalizeShapeParams(planned) + out.Output.ShapeTypes = normalizeShapeTypes(planned, sourcePath) + if planned.TypeContext != nil { + out.Output.ShapeTypeCtx = normalizeTypeContextIR( + planned.TypeContext.DefaultPackage, + planned.TypeContext.PackageDir, + planned.TypeContext.PackageName, + planned.TypeContext.PackagePath, + ) + } + } + + out.Output.Mismatches = compareParity(out.Output.LegacyViews, out.Output.ShapeViews) + out.Output.Mismatches = append(out.Output.Mismatches, compareMetadataParity(out.Output.LegacyMeta, out.Output.ShapeMeta, out.Output.LegacyViewMeta, out.Output.ShapeViewMeta)...) + out.Output.Mismatches = append(out.Output.Mismatches, compareParamParity(out.Output.LegacyParams, out.Output.ShapeParams)...) + out.Output.Mismatches = append(out.Output.Mismatches, compareTypeParity(out.Output.LegacyTypes, out.Output.ShapeTypes)...) + out.Output.Mismatches = append(out.Output.Mismatches, compareTypeContextParity(out.Output.LegacyTypeCtx, out.Output.ShapeTypeCtx)...) + out.Output.Mismatches = dedupe(out.Output.Mismatches) + return out +} + +func resolveLegacyRouteYAMLPath(routesRoot, namespace, source string) (string, bool) { + candidates := legacyRouteYAMLCandidatePaths(routesRoot, namespace, source) + for _, candidate := range candidates { + if _, err := os.Stat(candidate); err == nil { + return candidate, true + } + } + return "", false +} + +func legacyRouteYAMLCandidatePaths(routesRoot, namespace, source string) []string { + namespace = strings.Trim(strings.TrimSpace(namespace), "/") + stem := strings.TrimSuffix(filepath.Base(strings.TrimSpace(source)), filepath.Ext(strings.TrimSpace(source))) + if stem == "" { + stem = "route" + } + fileName := stem + ".yaml" + nsPath := filepath.FromSlash(namespace) + leaf := filepath.Base(nsPath) + parent := filepath.Dir(nsPath) + + appendUnique := func(items *[]string, seen map[string]bool, path string) { + path = filepath.Clean(path) + if path == "." 
|| path == "" || seen[path] { + return + } + seen[path] = true + *items = append(*items, path) + } + + seen := map[string]bool{} + result := make([]string, 0, 8) + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, fileName)) + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, stem, fileName)) + if leaf != "" && leaf != "." { + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, leaf+".yaml")) + } + if parent != "" && parent != "." { + appendUnique(&result, seen, filepath.Join(routesRoot, parent, fileName)) + appendUnique(&result, seen, filepath.Join(routesRoot, parent, stem, fileName)) + parentLeaf := filepath.Base(parent) + if parentLeaf != "" && parentLeaf != "." { + appendUnique(&result, seen, filepath.Join(routesRoot, parent, parentLeaf+".yaml")) + } + } + if strings.Contains(strings.ToLower(source), "/gen/") { + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, "patch", "patch.yaml")) + } + return result +} + +func collectRuleMappings(rulesRoot string) ([]parityRule, error) { + var files []string + if err := filepath.WalkDir(rulesRoot, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() || !strings.HasSuffix(path, ".yaml") { + return nil + } + files = append(files, path) + return nil + }); err != nil { + return nil, err + } + re := regexp.MustCompile(`\$appPath/bin/datly\s+(gen|translate)\s+.*-u=([^\s]+)\s+-s='([^']+)'(.*)`) + seen := map[string]bool{} + var result []parityRule + for _, file := range files { + data, err := os.ReadFile(file) + if err != nil { + continue + } + lines := strings.Split(string(data), "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + m := re.FindStringSubmatch(line) + if len(m) < 4 { + continue + } + src := strings.TrimSpace(m[3]) + if !(strings.HasSuffix(src, ".dql") || strings.HasSuffix(src, ".sql")) { + continue + } + connector := 
inferRuleConnector("") + if len(m) >= 5 { + connector = inferRuleConnector(m[4]) + } + key := m[2] + "|" + src + if seen[key] { + continue + } + seen[key] = true + result = append(result, parityRule{ + Mode: strings.TrimSpace(m[1]), + Namespace: strings.TrimSpace(m[2]), + Source: src, + Connector: connector, + }) + } + } + sort.Slice(result, func(i, j int) bool { + if result[i].Namespace == result[j].Namespace { + return result[i].Source < result[j].Source + } + return result[i].Namespace < result[j].Namespace + }) + return result, nil +} + +func inferRuleConnector(tail string) string { + lower := strings.ToLower(tail) + switch { + case strings.Contains(lower, "$optionsaero"): + return "system" + case strings.Contains(lower, "$optionssitemgmt"): + return "sitemgmt" + case strings.Contains(lower, "$options"): + return "ci_ads" + default: + return "" + } +} + +func routeYAMLName(source string) string { + base := filepath.Base(source) + ext := filepath.Ext(base) + return strings.TrimSuffix(base, ext) + ".yaml" +} + +func compareParity(legacy, shapeViews []viewIR) []string { + var result []string + if len(legacy) != len(shapeViews) { + result = append(result, "view count mismatch") + } + legacyByName := map[string]viewIR{} + for _, v := range legacy { + legacyByName[strings.ToLower(v.Name)] = v + } + for _, s := range shapeViews { + l, ok := legacyByName[strings.ToLower(s.Name)] + if !ok { + result = append(result, "missing view in legacy: "+s.Name) + continue + } + if l.Table != "" && s.Table != "" && !strings.EqualFold(l.Table, s.Table) { + result = append(result, "table mismatch: "+s.Name) + } + if l.Connector != "" && s.Connector == "" { + result = append(result, "connector missing in shape: "+s.Name) + } + if l.Connector != "" && s.Connector != "" && !strings.EqualFold(strings.TrimSpace(l.Connector), strings.TrimSpace(s.Connector)) { + result = append(result, "connector mismatch: "+s.Name) + } + if l.SQLURI != "" && s.SQLURI == "" { + result = append(result, "sql 
uri missing in shape: "+s.Name) + } + if l.SQLURI != "" && s.SQLURI != "" && !equalSQLURI(l.SQLURI, s.SQLURI) { + result = append(result, "sql uri mismatch: "+s.Name) + } + } + return dedupe(result) +} + +func equalSQLURI(legacy, shape string) bool { + normalize := func(v string) string { + v = strings.ReplaceAll(strings.TrimSpace(v), "\\", "/") + return strings.TrimPrefix(v, "./") + } + return strings.EqualFold(normalize(legacy), normalize(shape)) +} + +func normalizeCardinality(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "one": + return "One" + case "many": + return "Many" + default: + return strings.TrimSpace(value) + } +} + +func inferShapeViewMode(sql string) string { + sql = strings.TrimSpace(sql) + if sql == "" { + return "" + } + statements := dqlstmt.New(sql) + hasRead := false + hasExec := false + for _, item := range statements { + if item == nil { + continue + } + switch item.Kind { + case dqlstmt.KindRead: + hasRead = true + case dqlstmt.KindExec: + hasExec = true + } + } + switch { + case hasRead && !hasExec: + return "SQLQuery" + case hasExec && !hasRead: + return "SQLExec" + case hasRead && hasExec: + return "SQLExec" + } + stmt := strings.ToLower(sql) + if strings.HasPrefix(stmt, "select") || strings.HasPrefix(stmt, "with") { + return "SQLQuery" + } + return "" +} + +func mergeShapeViewMetadata(meta []viewMetaIR, views view.Views) { + if len(meta) == 0 || len(views) == 0 { + return + } + index := map[string]int{} + for i, item := range meta { + index[strings.ToLower(strings.TrimSpace(item.Name))] = i + } + for _, candidate := range views { + if candidate == nil { + continue + } + key := strings.ToLower(strings.TrimSpace(candidate.Name)) + pos, ok := index[key] + if !ok { + continue + } + if mode := strings.TrimSpace(string(candidate.Mode)); mode != "" { + meta[pos].Mode = mode + } + if meta[pos].Module == "" { + meta[pos].Module = strings.TrimSpace(candidate.Module) + } + if meta[pos].AllowNulls == nil { + 
meta[pos].AllowNulls = candidate.AllowNulls + } + if candidate.Selector != nil { + if meta[pos].SelectorNamespace == "" { + meta[pos].SelectorNamespace = strings.TrimSpace(candidate.Selector.Namespace) + } + if meta[pos].SelectorNoLimit == nil { + meta[pos].SelectorNoLimit = &candidate.Selector.NoLimit + } + } + if candidate.Schema != nil { + if meta[pos].SchemaCardinality == "" { + meta[pos].SchemaCardinality = strings.TrimSpace(string(candidate.Schema.Cardinality)) + } + if meta[pos].SchemaType == "" { + meta[pos].SchemaType = firstNonEmpty(strings.TrimSpace(candidate.Schema.DataType), strings.TrimSpace(candidate.Schema.Name)) + } + } + if candidate.Template != nil && candidate.Template.Summary != nil { + value := true + meta[pos].HasSummary = &value + } + } +} + +func compareMetadataParity(legacyMeta, shapeMeta *resourceMetaIR, legacyViews, shapeViews []viewMetaIR) []string { + var result []string + if legacyMeta != nil && legacyMeta.ColumnsDiscovery != nil { + if shapeMeta == nil || shapeMeta.ColumnsDiscovery == nil { + result = append(result, "resource columnsDiscovery missing in shape") + } else if *legacyMeta.ColumnsDiscovery != *shapeMeta.ColumnsDiscovery { + result = append(result, "resource columnsDiscovery mismatch") + } + } + legacyByName := map[string]viewMetaIR{} + for _, item := range legacyViews { + legacyByName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, shapeItem := range shapeViews { + key := strings.ToLower(strings.TrimSpace(shapeItem.Name)) + legacyItem, ok := legacyByName[key] + if !ok { + continue + } + if legacyItem.Mode != "" { + if shapeItem.Mode == "" { + result = append(result, "view mode missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(legacyItem.Mode, shapeItem.Mode) { + result = append(result, "view mode mismatch: "+shapeItem.Name) + } + } + if legacyItem.Module != "" { + if shapeItem.Module == "" { + result = append(result, "view module missing in shape: "+shapeItem.Name) + } else if 
!strings.EqualFold(strings.TrimSpace(legacyItem.Module), strings.TrimSpace(shapeItem.Module)) { + result = append(result, "view module mismatch: "+shapeItem.Name) + } + } + if legacyItem.AllowNulls != nil { + if shapeItem.AllowNulls == nil { + result = append(result, "view allowNulls missing in shape: "+shapeItem.Name) + } else if *legacyItem.AllowNulls != *shapeItem.AllowNulls { + result = append(result, "view allowNulls mismatch: "+shapeItem.Name) + } + } + if legacyItem.SelectorNamespace != "" { + if shapeItem.SelectorNamespace == "" { + result = append(result, "view selector namespace missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(strings.TrimSpace(legacyItem.SelectorNamespace), strings.TrimSpace(shapeItem.SelectorNamespace)) { + result = append(result, "view selector namespace mismatch: "+shapeItem.Name) + } + } + if legacyItem.SelectorNoLimit != nil { + if shapeItem.SelectorNoLimit == nil { + result = append(result, "view selector noLimit missing in shape: "+shapeItem.Name) + } else if *legacyItem.SelectorNoLimit != *shapeItem.SelectorNoLimit { + result = append(result, "view selector noLimit mismatch: "+shapeItem.Name) + } + } + if legacyItem.SchemaCardinality != "" { + if shapeItem.SchemaCardinality == "" { + result = append(result, "view schema cardinality missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(strings.TrimSpace(legacyItem.SchemaCardinality), strings.TrimSpace(shapeItem.SchemaCardinality)) { + result = append(result, "view schema cardinality mismatch: "+shapeItem.Name) + } + } + if legacyItem.SchemaType != "" { + if shapeItem.SchemaType == "" { + result = append(result, "view schema type missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(strings.TrimSpace(legacyItem.SchemaType), strings.TrimSpace(shapeItem.SchemaType)) { + result = append(result, "view schema type mismatch: "+shapeItem.Name) + } + } + if legacyItem.HasSummary != nil { + if shapeItem.HasSummary == nil { + result = 
append(result, "view template summary missing in shape: "+shapeItem.Name) + } else if *legacyItem.HasSummary != *shapeItem.HasSummary { + result = append(result, "view template summary mismatch: "+shapeItem.Name) + } + } + } + return dedupe(result) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} + +func normalizeLegacyParams(legacy legacyYAML) []paramIR { + querySelectors := map[string]string{} + querySelectorCacheable := map[string]*bool{} + querySelectorIn := map[string]string{} + for _, v := range legacy.Resource.Views { + viewName := strings.TrimSpace(v.Name) + for _, param := range []selectorParam{v.Selector.LimitParameter, v.Selector.OffsetParameter, v.Selector.PageParameter, v.Selector.FieldsParameter, v.Selector.OrderByParameter} { + name := strings.TrimSpace(param.Name) + if name == "" || viewName == "" { + continue + } + querySelectors[strings.ToLower(name)] = viewName + querySelectorIn[strings.ToLower(name)] = strings.TrimSpace(param.In.Name) + if param.Cacheable != nil { + value := *param.Cacheable + querySelectorCacheable[strings.ToLower(name)] = &value + } + } + } + result := make([]paramIR, 0, len(legacy.Resource.Parameters)) + seen := map[string]bool{} + for _, p := range legacy.Resource.Parameters { + name := strings.TrimSpace(p.Name) + item := paramIR{ + Name: name, + Kind: strings.TrimSpace(p.In.Kind), + In: strings.TrimSpace(p.In.Name), + Required: p.Required, + Cacheable: p.Cacheable, + URI: strings.TrimSpace(p.URI), + Value: strings.TrimSpace(p.Value), + } + if selector, ok := querySelectors[strings.ToLower(name)]; ok { + item.QuerySelector = selector + if item.Cacheable == nil { + item.Cacheable = querySelectorCacheable[strings.ToLower(name)] + } + } + for _, pred := range p.Predicates { + item.Predicates = append(item.Predicates, normalizePredicateSig(pred.Group, pred.Name, pred.Ensure, pred.Args)) + } + 
sort.Strings(item.Predicates) + result = append(result, item) + seen[strings.ToLower(name)] = true + } + for key, selector := range querySelectors { + if seen[key] { + continue + } + name := strings.TrimSpace(key) + if name == "" { + continue + } + legacyName := name + for _, v := range legacy.Resource.Views { + for _, param := range []selectorParam{v.Selector.LimitParameter, v.Selector.OffsetParameter, v.Selector.PageParameter, v.Selector.FieldsParameter, v.Selector.OrderByParameter} { + if strings.EqualFold(strings.TrimSpace(param.Name), key) { + legacyName = strings.TrimSpace(param.Name) + break + } + } + } + result = append(result, paramIR{ + Name: legacyName, + Kind: "query", + In: strings.TrimSpace(querySelectorIn[key]), + QuerySelector: selector, + Cacheable: querySelectorCacheable[key], + }) + } + sort.Slice(result, func(i, j int) bool { + if strings.EqualFold(result[i].Name, result[j].Name) { + if strings.EqualFold(result[i].Kind, result[j].Kind) { + return strings.ToLower(result[i].In) < strings.ToLower(result[j].In) + } + return strings.ToLower(result[i].Kind) < strings.ToLower(result[j].Kind) + } + return strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func normalizeLegacyTypes(legacy legacyYAML) []typeIR { + if len(legacy.Resource.Types) == 0 { + return nil + } + result := make([]typeIR, 0, len(legacy.Resource.Types)) + seen := map[string]bool{} + for _, item := range legacy.Resource.Types { + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + key := strings.ToLower(name) + if seen[key] { + continue + } + seen[key] = true + result = append(result, typeIR{ + Name: name, + Alias: strings.TrimSpace(item.Alias), + DataType: normalizeTypeSignature(item.DataType), + Cardinality: normalizeCardinality(strings.TrimSpace(item.Cardinality)), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + }) + } + sort.Slice(result, func(i, j int) bool { + return 
strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func normalizeShapeTypes(planned *plan.Result, sourcePath string) []typeIR { + if planned == nil { + return nil + } + modulePrefix := inferModulePrefix(sourcePath) + typeImportByAlias, typeImportByPkg := typeImports(planned) + byName := map[string]typeIR{} + + register := func(item typeIR, overwrite bool) { + name := strings.TrimSpace(item.Name) + if name == "" { + return + } + key := strings.ToLower(name) + if existing, ok := byName[key]; ok { + if (overwrite || existing.DataType == "") && item.DataType != "" { + existing.DataType = item.DataType + } + if (overwrite || existing.Cardinality == "") && item.Cardinality != "" { + existing.Cardinality = item.Cardinality + } + if (overwrite || existing.Package == "") && item.Package != "" { + existing.Package = item.Package + } + if (overwrite || existing.ModulePath == "") && item.ModulePath != "" { + existing.ModulePath = item.ModulePath + } + if overwrite && item.Alias != "" { + existing.Alias = item.Alias + } + byName[key] = existing + return + } + byName[key] = item + } + + for _, item := range planned.Views { + if item == nil { + continue + } + dataType := strings.TrimSpace(item.SchemaType) + name := typeNameFromDataType(dataType) + if name == "" && item.ElementType != nil { + name = strings.TrimSpace(item.ElementType.Name()) + if dataType == "" && name != "" { + dataType = "*" + name + } + } + if name == "" { + continue + } + pkg := packageFromDataType(dataType) + modulePath := "" + if strings.TrimSpace(item.Module) != "" && modulePrefix != "" { + modulePath = modulePrefix + strings.Trim(strings.TrimSpace(item.Module), "/") + } + if modulePath == "" && pkg != "" { + modulePath = firstNonEmpty(typeImportByAlias[strings.ToLower(pkg)], typeImportByPkg[strings.ToLower(pkg)]) + } + register(typeIR{ + Name: name, + DataType: normalizeTypeSignature(dataType), + Cardinality: 
normalizeCardinality(strings.TrimSpace(item.Cardinality)), + Package: pkg, + ModulePath: modulePath, + }, false) + } + + for _, item := range planned.States { + if item == nil || item.Schema == nil || strings.TrimSpace(item.Schema.DataType) == "" { + continue + } + dataType := strings.TrimSpace(item.Schema.DataType) + name := typeNameFromDataType(dataType) + if name == "" { + continue + } + pkg := packageFromDataType(dataType) + modulePath := firstNonEmpty(typeImportByAlias[strings.ToLower(pkg)], typeImportByPkg[strings.ToLower(pkg)]) + register(typeIR{ + Name: name, + DataType: normalizeTypeSignature(dataType), + Package: pkg, + ModulePath: modulePath, + }, false) + } + for _, item := range planned.Types { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + register(typeIR{ + Name: strings.TrimSpace(item.Name), + Alias: strings.TrimSpace(item.Alias), + DataType: normalizeTypeSignature(item.DataType), + Cardinality: normalizeCardinality(strings.TrimSpace(item.Cardinality)), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + }, true) + } + + result := make([]typeIR, 0, len(byName)) + for _, item := range byName { + result = append(result, item) + } + sort.Slice(result, func(i, j int) bool { + return strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func compareTypeParity(legacy, shapeTypes []typeIR) []string { + var result []string + if len(legacy) == 0 { + return nil + } + shapeByName := map[string]typeIR{} + for _, item := range shapeTypes { + shapeByName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, legacyType := range legacy { + key := strings.ToLower(strings.TrimSpace(legacyType.Name)) + shapeType, ok := shapeByName[key] + if !ok { + result = append(result, "missing type in shape: "+legacyType.Name) + continue + } + if legacyType.DataType != "" && shapeType.DataType != "" && legacyType.DataType != shapeType.DataType { + result = 
append(result, "type dataType mismatch: "+legacyType.Name) + } + if legacyType.Cardinality != "" && shapeType.Cardinality != "" && !strings.EqualFold(legacyType.Cardinality, shapeType.Cardinality) { + result = append(result, "type cardinality mismatch: "+legacyType.Name) + } + if legacyType.Package != "" && shapeType.Package != "" && !strings.EqualFold(legacyType.Package, shapeType.Package) { + result = append(result, "type package mismatch: "+legacyType.Name) + } + if legacyType.ModulePath != "" && shapeType.ModulePath != "" && !strings.EqualFold(legacyType.ModulePath, shapeType.ModulePath) { + result = append(result, "type module path mismatch: "+legacyType.Name) + } + if legacyType.Alias != "" && shapeType.Alias != "" && !strings.EqualFold(legacyType.Alias, shapeType.Alias) { + result = append(result, "type alias mismatch: "+legacyType.Name) + } + } + return dedupe(result) +} + +func normalizeTypeContextIR(defaultPackage, packageDir, packageName, packagePath string) *typeCtxIR { + ret := &typeCtxIR{ + DefaultPackage: strings.TrimSpace(defaultPackage), + PackageDir: strings.TrimSpace(packageDir), + PackageName: strings.TrimSpace(packageName), + PackagePath: strings.TrimSpace(packagePath), + } + if ret.DefaultPackage == "" && ret.PackageDir == "" && ret.PackageName == "" && ret.PackagePath == "" { + return nil + } + return ret +} + +func compareTypeContextParity(legacy, shape *typeCtxIR) []string { + if legacy == nil { + return nil + } + if shape == nil { + return []string{"missing type context in shape"} + } + var result []string + if legacy.DefaultPackage != "" && shape.DefaultPackage != "" && !strings.EqualFold(legacy.DefaultPackage, shape.DefaultPackage) { + result = append(result, "type context default package mismatch") + } + if legacy.PackageDir != "" && shape.PackageDir != "" && !strings.EqualFold(legacy.PackageDir, shape.PackageDir) { + result = append(result, "type context package dir mismatch") + } + if legacy.PackageName != "" && shape.PackageName != 
"" && !strings.EqualFold(legacy.PackageName, shape.PackageName) { + result = append(result, "type context package name mismatch") + } + if legacy.PackagePath != "" && shape.PackagePath != "" && !strings.EqualFold(legacy.PackagePath, shape.PackagePath) { + result = append(result, "type context package path mismatch") + } + return dedupe(result) +} + +func normalizeTypeSignature(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + parts := strings.Fields(value) + return strings.Join(parts, " ") +} + +func typeNameFromDataType(dataType string) string { + dataType = strings.TrimSpace(dataType) + if dataType == "" { + return "" + } + dataType = strings.TrimLeft(dataType, "*[]") + if dataType == "" { + return "" + } + if idx := strings.LastIndex(dataType, "."); idx != -1 { + dataType = dataType[idx+1:] + } + if idx := strings.Index(dataType, "{"); idx != -1 { + dataType = dataType[:idx] + } + return strings.TrimSpace(dataType) +} + +func packageFromDataType(dataType string) string { + dataType = strings.TrimSpace(dataType) + dataType = strings.TrimLeft(dataType, "*[]") + if idx := strings.LastIndex(dataType, "."); idx != -1 { + return strings.TrimSpace(dataType[:idx]) + } + return "" +} + +func inferModulePrefix(sourcePath string) string { + normalized := filepath.ToSlash(strings.TrimSpace(sourcePath)) + if normalized == "" { + return "" + } + const marker = "/src/" + idx := strings.Index(normalized, marker) + if idx == -1 { + return "" + } + root := normalized[idx+len(marker):] + if slash := strings.Index(root, "/dql/"); slash != -1 { + root = root[:slash] + } + root = strings.Trim(root, "/") + if root == "" { + return "" + } + return root + "/pkg/" +} + +func typeImports(planned *plan.Result) (map[string]string, map[string]string) { + byAlias := map[string]string{} + byPkg := map[string]string{} + if planned == nil || planned.TypeContext == nil { + return byAlias, byPkg + } + appendPkg := func(pkg string) { + pkg = 
strings.TrimSpace(pkg) + if pkg == "" { + return + } + base := pkg + if idx := strings.LastIndex(base, "/"); idx != -1 { + base = base[idx+1:] + } + base = strings.ToLower(strings.TrimSpace(base)) + if base != "" { + byPkg[base] = pkg + } + } + if packagePath := strings.TrimSpace(planned.TypeContext.PackagePath); packagePath != "" { + appendPkg(packagePath) + if pkgName := strings.ToLower(strings.TrimSpace(planned.TypeContext.PackageName)); pkgName != "" { + byAlias[pkgName] = packagePath + byPkg[pkgName] = packagePath + } + } + appendPkg(planned.TypeContext.DefaultPackage) + for _, item := range planned.TypeContext.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + if alias := strings.ToLower(strings.TrimSpace(item.Alias)); alias != "" { + byAlias[alias] = pkg + } + appendPkg(pkg) + } + return byAlias, byPkg +} + +func normalizeShapeParams(planned *plan.Result) []paramIR { + if planned == nil || len(planned.States) == 0 { + return nil + } + result := make([]paramIR, 0, len(planned.States)) + for _, s := range planned.States { + if s == nil { + continue + } + item := paramIR{ + Name: strings.TrimSpace(s.Name), + Kind: strings.TrimSpace(s.KindString()), + In: strings.TrimSpace(s.InName()), + Required: s.Required, + Cacheable: s.Cacheable, + URI: strings.TrimSpace(s.URI), + Value: strings.TrimSpace(fmt.Sprint(s.Value)), + QuerySelector: strings.TrimSpace(s.QuerySelector), + } + for _, pred := range s.Predicates { + if pred == nil { + continue + } + item.Predicates = append(item.Predicates, normalizePredicateSig(pred.Group, pred.Name, pred.Ensure, pred.Args)) + } + sort.Strings(item.Predicates) + result = append(result, item) + } + sort.Slice(result, func(i, j int) bool { + if strings.EqualFold(result[i].Name, result[j].Name) { + if strings.EqualFold(result[i].Kind, result[j].Kind) { + return strings.ToLower(result[i].In) < strings.ToLower(result[j].In) + } + return strings.ToLower(result[i].Kind) < strings.ToLower(result[j].Kind) + } 
+ return strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func normalizePredicateSig(group int, name string, ensure bool, args []string) string { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, strings.TrimSpace(arg)) + } + return strings.ToLower(strings.TrimSpace(name)) + "|" + strconv.Itoa(group) + "|" + strconv.FormatBool(ensure) + "|" + strings.Join(parts, ",") +} + +func compareParamParity(legacy, shapeParams []paramIR) []string { + var result []string + legacyByKey := map[string]paramIR{} + for _, item := range filterComparableParams(legacy) { + legacyByKey[paramKey(item)] = item + } + shapeByKey := map[string]paramIR{} + for _, item := range filterComparableParams(shapeParams) { + shapeByKey[paramKey(item)] = item + } + if len(legacyByKey) != len(shapeByKey) { + result = append(result, "parameter count mismatch") + } + for key, legacyItem := range legacyByKey { + shapeItem, ok := shapeByKey[key] + if !ok { + result = append(result, "missing parameter in shape: "+legacyItem.Name) + continue + } + if legacyItem.Required != nil && shapeItem.Required != nil && *legacyItem.Required != *shapeItem.Required { + result = append(result, "parameter required mismatch: "+legacyItem.Name) + } + if legacyItem.Cacheable != nil && shapeItem.Cacheable != nil && *legacyItem.Cacheable != *shapeItem.Cacheable { + result = append(result, "parameter cacheable mismatch: "+legacyItem.Name) + } + if legacyItem.QuerySelector != "" && !strings.EqualFold(legacyItem.QuerySelector, shapeItem.QuerySelector) { + result = append(result, "parameter query selector mismatch: "+legacyItem.Name) + } + if legacyItem.URI != "" && !strings.EqualFold(strings.TrimSpace(legacyItem.URI), strings.TrimSpace(shapeItem.URI)) { + result = append(result, "parameter uri mismatch: "+legacyItem.Name) + } + if len(legacyItem.Predicates) != len(shapeItem.Predicates) { + result = append(result, "parameter predicates count 
mismatch: "+legacyItem.Name) + continue + } + for i := range legacyItem.Predicates { + if legacyItem.Predicates[i] != shapeItem.Predicates[i] { + result = append(result, "parameter predicate mismatch: "+legacyItem.Name) + break + } + } + } + return dedupe(result) +} + +func paramKey(item paramIR) string { + kind := strings.ToLower(strings.TrimSpace(item.Kind)) + in := strings.ToLower(strings.TrimSpace(item.In)) + if kind == "component" { + in = normalizeComponentRef(in) + } + return strings.ToLower(strings.TrimSpace(item.Name)) + "|" + kind + "|" + in +} + +func normalizeComponentRef(in string) string { + in = strings.TrimSpace(strings.TrimPrefix(in, "get:")) + if in == "" { + return in + } + in = strings.TrimPrefix(in, "../") + in = strings.TrimPrefix(in, "./") + in = strings.TrimPrefix(in, "/") + if idx := strings.LastIndex(in, "/"); idx != -1 { + return in[idx+1:] + } + return in +} + +func filterComparableParams(items []paramIR) []paramIR { + if len(items) == 0 { + return nil + } + result := make([]paramIR, 0, len(items)) + for _, item := range items { + kind := strings.ToLower(strings.TrimSpace(item.Kind)) + switch kind { + case "output", "meta", "async": + continue + default: + result = append(result, item) + } + } + return result +} + +func dedupe(items []string) []string { + if len(items) == 0 { + return nil + } + seen := map[string]bool{} + var ret []string + for _, item := range items { + if item == "" || seen[item] { + continue + } + seen[item] = true + ret = append(ret, item) + } + sort.Strings(ret) + return ret +} + +func topIssues(counter map[string]int, limit int) []string { + type pair struct { + Issue string + Count int + } + var list []pair + for issue, count := range counter { + list = append(list, pair{Issue: issue, Count: count}) + } + sort.Slice(list, func(i, j int) bool { + if list[i].Count == list[j].Count { + return list[i].Issue < list[j].Issue + } + return list[i].Count > list[j].Count + }) + if len(list) > limit { + list = list[:limit] + 
} + var ret []string + for _, item := range list { + ret = append(ret, item.Issue) + } + return ret +} + +func writeIRFile(path string, v parityOutput) { + _ = os.MkdirAll(filepath.Dir(path), 0o755) + writeYAML(path, v) +} + +func writeYAML(path string, v interface{}) { + data, err := yaml.Marshal(v) + if err != nil { + return + } + _ = os.WriteFile(path, data, 0o644) +} diff --git a/repository/shape/platform_parity_types_test.go b/repository/shape/platform_parity_types_test.go new file mode 100644 index 000000000..76341f3f9 --- /dev/null +++ b/repository/shape/platform_parity_types_test.go @@ -0,0 +1,86 @@ +package shape_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" +) + +func TestNormalizeTypeSignature(t *testing.T) { + assert.Equal(t, "struct{ Id int; Name string }", normalizeTypeSignature(" struct{ Id int; Name string } ")) +} + +func TestTypeNameFromDataType(t *testing.T) { + assert.Equal(t, "TvAffiliateStationView", typeNameFromDataType("*tvaffiliatestation.TvAffiliateStationView")) + assert.Equal(t, "Output", typeNameFromDataType("*Output")) + assert.Equal(t, "struct", typeNameFromDataType("struct{Id int}")) +} + +func TestCompareTypeParity(t *testing.T) { + legacy := []typeIR{{ + Name: "TvAffiliateStationView", + DataType: "*tvaffiliatestation.TvAffiliateStationView", + Package: "tvaffiliatestation", + ModulePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + }} + shapeTypes := []typeIR{{ + Name: "TvAffiliateStationView", + DataType: "*tvaffiliatestation.TvAffiliateStationView", + Package: "tvaffiliatestation", + ModulePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + }} + assert.Empty(t, compareTypeParity(legacy, shapeTypes)) +} + +func TestNormalizeShapeTypes(t *testing.T) { + planned := &plan.Result{ + TypeContext: &typectx.Context{ + PackagePath: 
"github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + }, + Views: []*plan.View{ + { + Name: "tvAffiliateStation", + Module: "platform/tvaffiliatestation", + SchemaType: "*tvaffiliatestation.TvAffiliateStationView", + Cardinality: "many", + }, + }, + } + actual := normalizeShapeTypes(planned, "/Users/awitas/go/src/github.vianttech.com/viant/platform/dql/platform/tvaffiliatestation/tvaffiliatestation.dql") + if assert.Len(t, actual, 1) { + assert.Equal(t, "TvAffiliateStationView", actual[0].Name) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", actual[0].ModulePath) + assert.Equal(t, "Many", actual[0].Cardinality) + } +} + +func TestTypeImports_UsesTypeContextPackagePath(t *testing.T) { + planned := &plan.Result{ + TypeContext: &typectx.Context{ + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + }, + } + byAlias, byPkg := typeImports(planned) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", byAlias["tvaffiliatestation"]) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", byPkg["tvaffiliatestation"]) +} + +func TestCompareTypeContextParity(t *testing.T) { + legacy := &typeCtxIR{ + DefaultPackage: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageDir: "pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + } + shape := &typeCtxIR{ + DefaultPackage: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageDir: "pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + } + assert.Empty(t, compareTypeContextParity(legacy, shape)) +} diff --git a/repository/shape/scan/doc.go 
b/repository/shape/scan/doc.go new file mode 100644 index 000000000..e1f105775 --- /dev/null +++ b/repository/shape/scan/doc.go @@ -0,0 +1,2 @@ +// Package scan defines scanning responsibilities for struct/DQL inputs. +package scan diff --git a/repository/shape/scan/model.go b/repository/shape/scan/model.go new file mode 100644 index 000000000..357299250 --- /dev/null +++ b/repository/shape/scan/model.go @@ -0,0 +1,33 @@ +package scan + +import ( + "embed" + "reflect" + + "github.com/viant/datly/view/tags" +) + +// Result holds scan output produced from a struct source. +type Result struct { + RootType reflect.Type + EmbedFS *embed.FS + Fields []*Field + ByPath map[string]*Field + ViewFields []*Field + StateFields []*Field +} + +// Field describes one scanned struct field. +type Field struct { + Path string + Name string + Index []int + Type reflect.Type + Tag reflect.StructTag + Anonymous bool + + HasViewTag bool + HasStateTag bool + ViewTag *tags.Tag + StateTag *tags.Tag +} diff --git a/repository/shape/scan/scanner.go b/repository/shape/scan/scanner.go new file mode 100644 index 000000000..255f9cd4b --- /dev/null +++ b/repository/shape/scan/scanner.go @@ -0,0 +1,169 @@ +package scan + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/view/state" + "github.com/viant/datly/view/tags" +) + +// StructScanner scans arbitrary struct types and extracts Datly-relevant tags. +type StructScanner struct{} + +// New returns a Scanner implementation for shape facade. +func New() *StructScanner { + return &StructScanner{} +} + +// Scan implements shape.Scanner. 
+func (s *StructScanner) Scan(ctx context.Context, source *shape.Source, _ ...shape.ScanOption) (*shape.ScanResult, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if source == nil { + return nil, shape.ErrNilSource + } + source.EnsureTypeRegistry() + + root, err := resolveRootType(source) + if err != nil { + return nil, err + } + + embedder := resolveEmbedder(source) + result := &Result{ + RootType: root, + EmbedFS: embedder.EmbedFS(), + ByPath: map[string]*Field{}, + } + + if err = s.scanStruct(root, "", nil, embedder, result, map[reflect.Type]bool{}); err != nil { + return nil, err + } + + return &shape.ScanResult{Source: source, Descriptors: result}, nil +} + +func resolveRootType(source *shape.Source) (reflect.Type, error) { + rType, err := source.ResolveRootType() + if err != nil { + return nil, err + } + if rType == nil { + return nil, shape.ErrNilSource + } + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil, fmt.Errorf("shape scan: unsupported source type %v, expected struct", rType) + } + return rType, nil +} + +func resolveEmbedder(source *shape.Source) *state.FSEmbedder { + embedder := state.NewFSEmbedder(nil) + if source.Type != nil { + rType := source.Type + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + embedder.SetType(rType) + return embedder + } + if source.Struct != nil { + rType := reflect.TypeOf(source.Struct) + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + embedder.SetType(rType) + } + return embedder +} + +func (s *StructScanner) scanStruct( + rType reflect.Type, + prefix string, + indexPrefix []int, + embedder *state.FSEmbedder, + result *Result, + visited map[reflect.Type]bool, +) error { + if visited[rType] { + return nil + } + visited[rType] = true + defer delete(visited, rType) + + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + path := field.Name + if prefix != "" { + path = prefix + "." 
+ field.Name + } + combinedIndex := append(append([]int{}, indexPrefix...), field.Index...) + + descriptor := &Field{ + Path: path, + Name: field.Name, + Index: combinedIndex, + Type: field.Type, + Tag: field.Tag, + Anonymous: field.Anonymous, + } + + if hasAny(field.Tag, tags.ViewTag, tags.SQLTag, tags.SQLSummaryTag, tags.LinkOnTag) { + parsed, err := tags.ParseViewTags(field.Tag, embedder.EmbedFS()) + if err != nil { + return fmt.Errorf("shape scan: failed to parse view tags on %s: %w", path, err) + } + descriptor.HasViewTag = true + descriptor.ViewTag = parsed + result.ViewFields = append(result.ViewFields, descriptor) + } + + if hasAny(field.Tag, tags.ParameterTag, tags.SQLTag, tags.PredicateTag, tags.CodecTag, tags.HandlerTag) { + parsed, err := tags.ParseStateTags(field.Tag, embedder.EmbedFS()) + if err != nil { + return fmt.Errorf("shape scan: failed to parse state tags on %s: %w", path, err) + } + descriptor.HasStateTag = true + descriptor.StateTag = parsed + result.StateFields = append(result.StateFields, descriptor) + } + + result.Fields = append(result.Fields, descriptor) + result.ByPath[path] = descriptor + + nextType := field.Type + for nextType.Kind() == reflect.Ptr { + nextType = nextType.Elem() + } + if field.Anonymous && nextType.Kind() == reflect.Struct && !isStdlib(nextType.PkgPath()) { + if err := s.scanStruct(nextType, path, combinedIndex, embedder, result, visited); err != nil { + return err + } + } + } + return nil +} + +func hasAny(tag reflect.StructTag, names ...string) bool { + for _, name := range names { + if _, ok := tag.Lookup(name); ok { + return true + } + } + return false +} + +func isStdlib(pkg string) bool { + if pkg == "" { + return true + } + return !strings.Contains(pkg, ".") +} diff --git a/repository/shape/scan/scanner_test.go b/repository/shape/scan/scanner_test.go new file mode 100644 index 000000000..bf57d5cec --- /dev/null +++ b/repository/shape/scan/scanner_test.go @@ -0,0 +1,83 @@ +package scan + +import ( + "context" + 
"embed" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/x" +) + +//go:embed testdata/*.sql +var testFS embed.FS + +type embeddedFS struct{} + +func (embeddedFS) EmbedFS() *embed.FS { + return &testFS +} + +type reportRow struct { + ID int + Name string +} + +type reportSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT,connector=dev" sql:"uri=testdata/report.sql"` + ID int `parameter:"id,kind=query,in=id"` +} + +func TestStructScanner_Scan(t *testing.T) { + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) + require.NoError(t, err) + require.NotNil(t, result) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + require.NotNil(t, descriptors) + require.NotNil(t, descriptors.EmbedFS) + assert.Equal(t, reflect.TypeOf(reportSource{}), descriptors.RootType) + + rows := descriptors.ByPath["Rows"] + require.NotNil(t, rows) + require.True(t, rows.HasViewTag) + require.NotNil(t, rows.ViewTag) + assert.Equal(t, "rows", rows.ViewTag.View.Name) + assert.Contains(t, rows.ViewTag.SQL.SQL, "SELECT ID, NAME FROM REPORT") + + idField := descriptors.ByPath["ID"] + require.NotNil(t, idField) + require.True(t, idField.HasStateTag) + require.NotNil(t, idField.StateTag) + require.NotNil(t, idField.StateTag.Parameter) + assert.Equal(t, "id", idField.StateTag.Parameter.Name) + assert.Equal(t, "query", idField.StateTag.Parameter.Kind) + assert.Equal(t, "id", idField.StateTag.Parameter.In) +} + +func TestStructScanner_Scan_InvalidSource(t *testing.T) { + scanner := New() + _, err := scanner.Scan(context.Background(), &shape.Source{Struct: 1}) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected struct") +} + +func TestStructScanner_Scan_WithRegistryType(t *testing.T) { + scanner := New() + registry := x.NewRegistry() + 
registry.Register(x.NewType(reflect.TypeOf(reportSource{}))) + result, err := scanner.Scan(context.Background(), &shape.Source{ + TypeName: "github.com/viant/datly/repository/shape/scan.reportSource", + TypeRegistry: registry, + }) + require.NoError(t, err) + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + assert.Equal(t, reflect.TypeOf(reportSource{}), descriptors.RootType) +} diff --git a/repository/shape/scan/spec.go b/repository/shape/scan/spec.go new file mode 100644 index 000000000..69e3eb185 --- /dev/null +++ b/repository/shape/scan/spec.go @@ -0,0 +1,16 @@ +package scan + +import "github.com/viant/datly/repository/shape" + +// ShapeSpecKind implements shape.ScanSpec. +func (r *Result) ShapeSpecKind() string { return "scan" } + +// DescriptorsFrom extracts the typed scan result from a ScanResult. +// Returns (nil, false) when a is nil or contains an unexpected concrete type. +func DescriptorsFrom(a *shape.ScanResult) (*Result, bool) { + if a == nil { + return nil, false + } + r, ok := a.Descriptors.(*Result) + return r, ok && r != nil +} diff --git a/repository/shape/scan/testdata/report.sql b/repository/shape/scan/testdata/report.sql new file mode 100644 index 000000000..68f0f3b34 --- /dev/null +++ b/repository/shape/scan/testdata/report.sql @@ -0,0 +1 @@ +SELECT ID, NAME FROM REPORT diff --git a/repository/shape/shape.go b/repository/shape/shape.go new file mode 100644 index 000000000..5f7f766d8 --- /dev/null +++ b/repository/shape/shape.go @@ -0,0 +1,198 @@ +package shape + +import "context" + +type ( + CompileMixedMode string + CompileUnknownNonReadMode string + CompileProfile string + CompileColumnDiscoveryMode string + + // Scanner discovers shape descriptors from Source. + Scanner interface { + Scan(ctx context.Context, source *Source, opts ...ScanOption) (*ScanResult, error) + } + + // Planner normalizes discovered descriptors into execution plan. 
+ Planner interface { + Plan(ctx context.Context, scan *ScanResult, opts ...PlanOption) (*PlanResult, error) + } + + // Loader materializes runtime artifacts from normalized plan. + Loader interface { + LoadViews(ctx context.Context, plan *PlanResult, opts ...LoadOption) (*ViewArtifacts, error) + LoadComponent(ctx context.Context, plan *PlanResult, opts ...LoadOption) (*ComponentArtifact, error) + } + + // DQLCompiler compiles DQL source directly into a shape plan. + DQLCompiler interface { + Compile(ctx context.Context, source *Source, opts ...CompileOption) (*PlanResult, error) + } + + // RuntimeRegistrar optionally registers loaded artifacts in runtime services. + RuntimeRegistrar interface { + RegisterViews(ctx context.Context, artifacts *ViewArtifacts) error + RegisterComponent(ctx context.Context, artifacts *ComponentArtifact) error + } + + ScanOptions struct{} + PlanOptions struct{} + LoadOptions struct{} + CompileOptions struct { + Strict bool + Profile CompileProfile + MixedMode CompileMixedMode + UnknownNonReadMode CompileUnknownNonReadMode + ColumnDiscoveryMode CompileColumnDiscoveryMode + DQLPathMarker string + RoutesRelativePath string + TypePackageDir string + TypePackageName string + TypePackagePath string + InferTypeContext *bool + } + + ScanOption func(*ScanOptions) + PlanOption func(*PlanOptions) + LoadOption func(*LoadOptions) + CompileOption func(*CompileOptions) +) + +const ( + CompileMixedModeExecWins CompileMixedMode = "exec_wins" + CompileMixedModeReadWins CompileMixedMode = "read_wins" + CompileMixedModeErrorOnMixed CompileMixedMode = "error_on_mixed" + + CompileUnknownNonReadWarn CompileUnknownNonReadMode = "warn" + CompileUnknownNonReadError CompileUnknownNonReadMode = "error" + + CompileProfileCompat CompileProfile = "compat" + CompileProfileStrict CompileProfile = "strict" + + CompileColumnDiscoveryAuto CompileColumnDiscoveryMode = "auto" + CompileColumnDiscoveryOn CompileColumnDiscoveryMode = "on" + CompileColumnDiscoveryOff 
CompileColumnDiscoveryMode = "off" +) + +// Engine is a thin facade over scan -> plan -> load pipeline. +type Engine struct { + options *Options +} + +// New creates an Engine facade. +func New(opts ...Option) *Engine { + return &Engine{options: NewOptions(opts...)} +} + +// LoadViews is a package-level helper for struct source view loading. +func LoadViews(ctx context.Context, src any, opts ...Option) (*ViewArtifacts, error) { + return New(opts...).LoadViews(ctx, src) +} + +// LoadComponent is a package-level helper for struct source component loading. +func LoadComponent(ctx context.Context, src any, opts ...Option) (*ComponentArtifact, error) { + return New(opts...).LoadComponent(ctx, src) +} + +// LoadDQLViews is a package-level helper for DQL source view loading. +func LoadDQLViews(ctx context.Context, dql string, opts ...Option) (*ViewArtifacts, error) { + return New(opts...).LoadDQLViews(ctx, dql) +} + +// LoadDQLComponent is a package-level helper for DQL source component loading. +func LoadDQLComponent(ctx context.Context, dql string, opts ...Option) (*ComponentArtifact, error) { + return New(opts...).LoadDQLComponent(ctx, dql) +} + +// LoadViews executes scan -> plan -> load for struct source. +func (e *Engine) LoadViews(ctx context.Context, src any) (*ViewArtifacts, error) { + source, err := e.structSource(src) + if err != nil { + return nil, err + } + plan, err := e.scanAndPlan(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadViews(ctx, plan) +} + +// LoadComponent executes scan -> plan -> load for struct source. 
+func (e *Engine) LoadComponent(ctx context.Context, src any) (*ComponentArtifact, error) { + source, err := e.structSource(src) + if err != nil { + return nil, err + } + plan, err := e.scanAndPlan(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadComponent(ctx, plan) +} + +// LoadDQLViews executes compile -> load for DQL source. +func (e *Engine) LoadDQLViews(ctx context.Context, dql string) (*ViewArtifacts, error) { + source, err := e.dqlSource(dql) + if err != nil { + return nil, err + } + plan, err := e.compile(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadViews(ctx, plan) +} + +// LoadDQLComponent executes compile -> load for DQL source. +func (e *Engine) LoadDQLComponent(ctx context.Context, dql string) (*ComponentArtifact, error) { + source, err := e.dqlSource(dql) + if err != nil { + return nil, err + } + plan, err := e.compile(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadComponent(ctx, plan) +} + +func (e *Engine) compile(ctx context.Context, source *Source) (*PlanResult, error) { + if e.options.Compiler == nil { + return nil, ErrCompilerNotConfigured + } + return e.options.Compiler.Compile( + ctx, + source, + WithCompileStrict(e.options.Strict), + WithCompileProfile(e.options.CompileProfile), + WithMixedMode(e.options.CompileMixedMode), + WithUnknownNonReadMode(e.options.UnknownNonReadMode), + WithColumnDiscoveryMode(e.options.ColumnDiscoveryMode), + ) +} + +func (e *Engine) scanAndPlan(ctx context.Context, source *Source) (*PlanResult, error) { + if e.options.Scanner == nil { + return nil, ErrScannerNotConfigured + } + if e.options.Planner == nil { + return nil, ErrPlannerNotConfigured + } + scanResult, err := 
e.options.Scanner.Scan(ctx, source) + if err != nil { + return nil, err + } + return e.options.Planner.Plan(ctx, scanResult) +} diff --git a/repository/shape/source.go b/repository/shape/source.go new file mode 100644 index 000000000..e408c2163 --- /dev/null +++ b/repository/shape/source.go @@ -0,0 +1,39 @@ +package shape + +import ( + "reflect" + "strings" + + "github.com/viant/x" +) + +func (e *Engine) structSource(src any) (*Source, error) { + if src == nil { + return nil, ErrNilSource + } + rType := reflect.TypeOf(src) + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + registry := x.NewRegistry() + registry.Register(x.NewType(rType)) + return &Source{ + Name: e.options.Name, + Struct: src, + Type: rType, + TypeName: x.NewType(rType).Key(), + TypeRegistry: registry, + DQL: "", + }, nil +} + +func (e *Engine) dqlSource(dql string) (*Source, error) { + dql = strings.TrimSpace(dql) + if dql == "" { + return nil, ErrNilDQL + } + return &Source{ + Name: e.options.Name, + DQL: dql, + }, nil +} diff --git a/repository/shape/source_type.go b/repository/shape/source_type.go new file mode 100644 index 000000000..51bc7132e --- /dev/null +++ b/repository/shape/source_type.go @@ -0,0 +1,56 @@ +package shape + +import ( + "fmt" + "reflect" + "strings" + + "github.com/viant/x" +) + +// ResolveRootType resolves source root type from explicit Type, Struct, or viant/x registry. 
+func (s *Source) ResolveRootType() (reflect.Type, error) { + if s == nil { + return nil, ErrNilSource + } + if s.Type != nil { + return unwrapPtr(s.Type), nil + } + if s.Struct != nil { + return unwrapPtr(reflect.TypeOf(s.Struct)), nil + } + key := strings.TrimSpace(s.TypeName) + if key == "" || s.TypeRegistry == nil { + return nil, ErrNilSource + } + aType := s.TypeRegistry.Lookup(key) + if aType == nil || aType.Type == nil { + return nil, fmt.Errorf("shape source: type %q not found in registry", key) + } + return unwrapPtr(aType.Type), nil +} + +// EnsureTypeRegistry returns source registry ensuring root type is registered when available. +func (s *Source) EnsureTypeRegistry() *x.Registry { + if s == nil { + return nil + } + if s.TypeRegistry == nil { + s.TypeRegistry = x.NewRegistry() + } + if rType, err := s.ResolveRootType(); err == nil && rType != nil { + t := x.NewType(rType) + if strings.TrimSpace(s.TypeName) == "" { + s.TypeName = t.Key() + } + s.TypeRegistry.Register(t) + } + return s.TypeRegistry +} + +func unwrapPtr(rType reflect.Type) reflect.Type { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + return rType +} diff --git a/repository/shape/source_type_test.go b/repository/shape/source_type_test.go new file mode 100644 index 000000000..3118f8fed --- /dev/null +++ b/repository/shape/source_type_test.go @@ -0,0 +1,33 @@ +package shape + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/x" +) + +type sampleShape struct { + ID int +} + +func TestSource_ResolveRootType_FromRegistry(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(sampleShape{}))) + src := &Source{ + TypeName: "github.com/viant/datly/repository/shape.sampleShape", + TypeRegistry: registry, + } + rType, err := src.ResolveRootType() + require.NoError(t, err) + require.Equal(t, reflect.TypeOf(sampleShape{}), rType) +} + +func TestSource_EnsureTypeRegistry_RegistersRoot(t 
*testing.T) { + src := &Source{Struct: &sampleShape{}} + registry := src.EnsureTypeRegistry() + require.NotNil(t, registry) + require.NotEmpty(t, src.TypeName) + require.NotNil(t, registry.Lookup(src.TypeName)) +} diff --git a/repository/shape/typectx/context.go b/repository/shape/typectx/context.go new file mode 100644 index 000000000..072e22b9f --- /dev/null +++ b/repository/shape/typectx/context.go @@ -0,0 +1,89 @@ +package typectx + +import ( + "path" + "strings" +) + +// ValidationIssue captures context consistency problems. +type ValidationIssue struct { + Field string + Message string +} + +// Normalize trims and canonicalizes context fields. +func Normalize(input *Context) *Context { + if input == nil { + return nil + } + ret := &Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + PackageDir: cleanSlashes(strings.TrimSpace(input.PackageDir)), + PackageName: strings.TrimSpace(input.PackageName), + PackagePath: cleanSlashes(strings.TrimSpace(input.PackagePath)), + } + if ret.PackageName == "" { + if ret.PackagePath != "" { + ret.PackageName = path.Base(ret.PackagePath) + } else if ret.PackageDir != "" { + ret.PackageName = path.Base(ret.PackageDir) + } + } + if ret.DefaultPackage == "" && ret.PackagePath != "" { + ret.DefaultPackage = ret.PackagePath + } + for _, item := range input.Imports { + pkg := cleanSlashes(strings.TrimSpace(item.Package)) + if pkg == "" { + continue + } + alias := strings.TrimSpace(item.Alias) + if alias == "" { + alias = path.Base(pkg) + } + ret.Imports = append(ret.Imports, Import{ + Alias: alias, + Package: pkg, + }) + } + if ret.DefaultPackage == "" && + len(ret.Imports) == 0 && + ret.PackageDir == "" && + ret.PackageName == "" && + ret.PackagePath == "" { + return nil + } + return ret +} + +// Validate checks context consistency. 
+func Validate(ctx *Context) []ValidationIssue { + ctx = Normalize(ctx) + if ctx == nil { + return nil + } + var result []ValidationIssue + if strings.Contains(ctx.PackageName, "/") { + result = append(result, ValidationIssue{ + Field: "PackageName", + Message: "package name must not contain path separators", + }) + } + if ctx.PackagePath != "" && strings.Contains(ctx.PackagePath, ".") { + base := path.Base(ctx.PackagePath) + if ctx.PackageName != "" && base != ctx.PackageName { + result = append(result, ValidationIssue{ + Field: "PackagePath", + Message: "package path basename differs from package name", + }) + } + } + return result +} + +func cleanSlashes(value string) string { + value = strings.ReplaceAll(value, "\\", "/") + value = strings.TrimSpace(value) + value = strings.Trim(value, "/") + return value +} diff --git a/repository/shape/typectx/context_test.go b/repository/shape/typectx/context_test.go new file mode 100644 index 000000000..325709448 --- /dev/null +++ b/repository/shape/typectx/context_test.go @@ -0,0 +1,31 @@ +package typectx + +import "testing" + +func TestNormalize_FillsPackageFields(t *testing.T) { + ctx := Normalize(&Context{ + PackageDir: "pkg/platform/taxonomy", + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/taxonomy", + }) + if ctx == nil { + t.Fatalf("expected normalized context") + } + if ctx.PackageName != "taxonomy" { + t.Fatalf("expected package name taxonomy, got %q", ctx.PackageName) + } + if ctx.DefaultPackage != "github.vianttech.com/viant/platform/pkg/platform/taxonomy" { + t.Fatalf("expected default package from package path, got %q", ctx.DefaultPackage) + } +} + +func TestValidate_DetectsInvalidPackageName(t *testing.T) { + issues := Validate(&Context{ + PackageName: "platform/taxonomy", + }) + if len(issues) == 0 { + t.Fatalf("expected validation issue") + } + if issues[0].Field != "PackageName" { + t.Fatalf("expected PackageName issue, got %q", issues[0].Field) + } +} diff --git 
a/repository/shape/typectx/model.go b/repository/shape/typectx/model.go new file mode 100644 index 000000000..acc03ca18 --- /dev/null +++ b/repository/shape/typectx/model.go @@ -0,0 +1,32 @@ +package typectx + +// Import describes one package alias import for DQL/type resolution. +type Import struct { + Alias string `json:",omitempty" yaml:",omitempty"` + Package string `json:",omitempty" yaml:",omitempty"` +} + +// Context captures default package and imports used for type resolution. +type Context struct { + DefaultPackage string `json:",omitempty" yaml:",omitempty"` + Imports []Import `json:",omitempty" yaml:",omitempty"` + PackageDir string `json:",omitempty" yaml:",omitempty"` + PackageName string `json:",omitempty" yaml:",omitempty"` + PackagePath string `json:",omitempty" yaml:",omitempty"` +} + +// Provenance tracks where a resolved type came from. +type Provenance struct { + Package string `json:",omitempty" yaml:",omitempty"` + File string `json:",omitempty" yaml:",omitempty"` + Kind string `json:",omitempty" yaml:",omitempty"` // builtin, resource_type, registry, ast_type +} + +// Resolution captures one resolved type expression and its provenance. +type Resolution struct { + Expression string `json:",omitempty" yaml:",omitempty"` + Target string `json:",omitempty" yaml:",omitempty"` + ResolvedKey string `json:",omitempty" yaml:",omitempty"` + MatchKind string `json:",omitempty" yaml:",omitempty"` // exact, alias_import, qualified, default_package, import_package, global_unique + Provenance Provenance `json:",omitempty" yaml:",omitempty"` +} diff --git a/repository/shape/typectx/resolver.go b/repository/shape/typectx/resolver.go new file mode 100644 index 000000000..892c9ef6f --- /dev/null +++ b/repository/shape/typectx/resolver.go @@ -0,0 +1,273 @@ +package typectx + +import ( + "fmt" + "sort" + "strings" + + "github.com/viant/x" +) + +// AmbiguityError reports multiple matching type candidates for a type expression. 
+type AmbiguityError struct { + Expression string + Candidates []string +} + +func (e *AmbiguityError) Error() string { + return fmt.Sprintf("ambiguous type %q: candidates=%s", e.Expression, strings.Join(e.Candidates, ",")) +} + +// Resolver resolves cast/tag type expressions against viant/x registry using type context. +type Resolver struct { + registry *x.Registry + context *Context + provenance map[string]Provenance +} + +// NewResolver creates a type resolver. +func NewResolver(registry *x.Registry, context *Context) *Resolver { + return NewResolverWithProvenance(registry, context, nil) +} + +// NewResolverWithProvenance creates a type resolver with optional registry-key provenance map. +func NewResolverWithProvenance(registry *x.Registry, context *Context, provenance map[string]Provenance) *Resolver { + return &Resolver{ + registry: registry, + context: normalizeContext(context), + provenance: cloneProvenance(provenance), + } +} + +// Resolve resolves type expression to registry key. It returns ("", nil) when unresolved. +func (r *Resolver) Resolve(typeExpr string) (string, error) { + resolved, err := r.ResolveWithProvenance(typeExpr) + if err != nil || resolved == nil { + return "", err + } + return resolved.ResolvedKey, nil +} + +// ResolveWithProvenance resolves expression and returns provenance details. +// It returns (nil, nil) when unresolved. 
+func (r *Resolver) ResolveWithProvenance(typeExpr string) (*Resolution, error) { + if r == nil || r.registry == nil { + return nil, nil + } + base := normalizeLookupKey(typeExpr) + if base == "" { + return nil, nil + } + + // Exact type key (builtins or fully-qualified package.Type) + if r.registry.Lookup(base) != nil { + return r.newResolution(typeExpr, "", base, "exact"), nil + } + + prefix, baseName, alias, qualified := splitQualified(base) + if qualified { + if prefix == "" || baseName == "" { + return nil, nil + } + if alias { + pkg := r.aliasPackage(prefix) + if pkg == "" { + return nil, nil + } + candidate := pkg + "." + baseName + if r.registry.Lookup(candidate) == nil { + return nil, nil + } + return r.newResolution(typeExpr, "", candidate, "alias_import"), nil + } + // fully qualified package path.Type + if r.registry.Lookup(base) != nil { + return r.newResolution(typeExpr, "", base, "qualified"), nil + } + return nil, nil + } + + // Unqualified resolution: default package, then imports; if still unresolved, + // fallback to unique global name match. 
+ candidates := r.unqualifiedCandidates(baseName) + if len(candidates) == 1 { + return r.newResolution(typeExpr, "", candidates[0].key, candidates[0].matchKind), nil + } + if len(candidates) > 1 { + keys := make([]string, 0, len(candidates)) + for _, candidate := range candidates { + keys = append(keys, candidate.key) + } + sort.Strings(keys) + return nil, &AmbiguityError{Expression: typeExpr, Candidates: keys} + } + return nil, nil +} + +func (r *Resolver) aliasPackage(alias string) string { + alias = strings.TrimSpace(alias) + if alias == "" || r.context == nil { + return "" + } + for _, item := range r.context.Imports { + if item.Alias == alias { + return item.Package + } + } + if r.context.PackageName != "" && r.context.PackagePath != "" && r.context.PackageName == alias { + return r.context.PackagePath + } + return "" +} + +type candidate struct { + key string + matchKind string +} + +func (r *Resolver) unqualifiedCandidates(typeName string) []candidate { + if typeName == "" { + return nil + } + seen := map[string]bool{} + var result []candidate + + for _, scoped := range r.searchPackages() { + pkg := scoped.pkg + key := pkg + "." + typeName + if seen[key] { + continue + } + seen[key] = true + if r.registry.Lookup(key) != nil { + result = append(result, candidate{key: key, matchKind: scoped.matchKind}) + } + } + if len(result) > 0 { + return result + } + + // Global unique fallback by suffix ".TypeName" or exact built-in. 
+ for _, key := range r.registry.Keys() { + if key == typeName || strings.HasSuffix(key, "."+typeName) { + if seen[key] { + continue + } + seen[key] = true + result = append(result, candidate{key: key, matchKind: "global_unique"}) + } + } + return result +} + +type scopedPackage struct { + pkg string + matchKind string +} + +func (r *Resolver) searchPackages() []scopedPackage { + if r.context == nil { + return nil + } + seen := map[string]bool{} + var result []scopedPackage + appendPkg := func(pkg, matchKind string) { + pkg = strings.TrimSpace(pkg) + if pkg == "" || seen[pkg] { + return + } + seen[pkg] = true + result = append(result, scopedPackage{pkg: pkg, matchKind: matchKind}) + } + appendPkg(r.context.PackagePath, "package_path") + appendPkg(r.context.DefaultPackage, "default_package") + for _, item := range r.context.Imports { + appendPkg(item.Package, "import_package") + } + return result +} + +func (r *Resolver) newResolution(expression, target, key, matchKind string) *Resolution { + if key == "" { + return nil + } + resolution := &Resolution{ + Expression: strings.TrimSpace(expression), + Target: strings.TrimSpace(target), + ResolvedKey: key, + MatchKind: matchKind, + Provenance: r.lookupProvenance(key), + } + return resolution +} + +func (r *Resolver) lookupProvenance(key string) Provenance { + prov := Provenance{ + Package: packageOf(key), + Kind: "registry", + } + if built, ok := r.provenance[key]; ok { + if built.Package != "" { + prov.Package = built.Package + } + if built.File != "" { + prov.File = built.File + } + if built.Kind != "" { + prov.Kind = built.Kind + } + } + return prov +} + +func cloneProvenance(input map[string]Provenance) map[string]Provenance { + if len(input) == 0 { + return nil + } + result := make(map[string]Provenance, len(input)) + for k, v := range input { + result[k] = v + } + return result +} + +func packageOf(key string) string { + index := strings.LastIndex(key, ".") + if index == -1 { + return "" + } + return key[:index] 
+} + +func normalizeContext(input *Context) *Context { + return Normalize(input) +} + +func splitQualified(value string) (prefix string, name string, alias bool, qualified bool) { + index := strings.LastIndex(value, ".") + if index == -1 { + return "", value, false, false + } + prefix = strings.TrimSpace(value[:index]) + name = strings.TrimSpace(value[index+1:]) + if prefix == "" || name == "" { + return "", "", false, false + } + qualified = true + alias = !strings.Contains(prefix, "/") + return prefix, name, alias, qualified +} + +func normalizeLookupKey(typeExpr string) string { + value := strings.TrimSpace(typeExpr) + for { + switch { + case strings.HasPrefix(value, "*"): + value = strings.TrimPrefix(value, "*") + case strings.HasPrefix(value, "[]"): + value = strings.TrimPrefix(value, "[]") + default: + return strings.TrimSpace(value) + } + } +} diff --git a/repository/shape/typectx/resolver_matrix_test.go b/repository/shape/typectx/resolver_matrix_test.go new file mode 100644 index 000000000..473ba368b --- /dev/null +++ b/repository/shape/typectx/resolver_matrix_test.go @@ -0,0 +1,86 @@ +package typectx + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/x" +) + +type matrixOrderDefault struct{} +type matrixOrderImport struct{} +type matrixOrderPkgPath struct{} +type matrixOrderAliasImport struct{} + +func TestResolver_ResolutionMatrix(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(matrixOrderDefault{}), x.WithPkgPath("github.com/acme/default"), x.WithName("Order"))) + reg.Register(x.NewType(reflect.TypeOf(matrixOrderImport{}), x.WithPkgPath("github.com/acme/imported"), x.WithName("ImportedOrder"))) + reg.Register(x.NewType(reflect.TypeOf(matrixOrderPkgPath{}), x.WithPkgPath("github.com/acme/pkgpath"), x.WithName("Order"))) + reg.Register(x.NewType(reflect.TypeOf(matrixOrderAliasImport{}), x.WithPkgPath("github.com/acme/alias/import"), x.WithName("Order"))) + + testCases := 
[]struct { + name string + context *Context + expr string + wantKey string + ambiguous bool + }{ + { + name: "only default/imports", + context: &Context{ + DefaultPackage: "github.com/acme/default", + Imports: []Import{{Alias: "imp", Package: "github.com/acme/imported"}}, + }, + expr: "Order", + wantKey: "github.com/acme/default.Order", + }, + { + name: "only package triple", + context: &Context{ + PackagePath: "github.com/acme/pkgpath", + PackageName: "pkgpath", + PackageDir: "pkg/pkgpath", + }, + expr: "Order", + wantKey: "github.com/acme/pkgpath.Order", + }, + { + name: "default and package path conflict", + context: &Context{ + DefaultPackage: "github.com/acme/default", + PackagePath: "github.com/acme/pkgpath", + PackageName: "pkgpath", + }, + expr: "Order", + ambiguous: true, + }, + { + name: "alias import wins over package-name fallback", + context: &Context{ + PackagePath: "github.com/acme/pkgpath", + PackageName: "same", + Imports: []Import{{Alias: "same", Package: "github.com/acme/alias/import"}}, + }, + expr: "same.Order", + wantKey: "github.com/acme/alias/import.Order", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + resolver := NewResolver(reg, testCase.context) + key, err := resolver.Resolve(testCase.expr) + if testCase.ambiguous { + require.Error(t, err) + _, ok := err.(*AmbiguityError) + require.True(t, ok) + require.Empty(t, key) + return + } + require.NoError(t, err) + require.Equal(t, testCase.wantKey, key) + }) + } +} diff --git a/repository/shape/typectx/resolver_memfs_test.go b/repository/shape/typectx/resolver_memfs_test.go new file mode 100644 index 000000000..cc90b4700 --- /dev/null +++ b/repository/shape/typectx/resolver_memfs_test.go @@ -0,0 +1,116 @@ +package typectx + +import ( + "context" + "path" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" + "github.com/viant/x" + xast "github.com/viant/x/loader/ast" +) + +func TestResolver_MemFS_DefaultPackageResolution(t 
*testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf"}, &Context{ + DefaultPackage: "example.com/acme/perf", + }) + + key, err := resolver.Resolve("Order") + require.NoError(t, err) + require.Equal(t, "example.com/acme/perf.Order", key) +} + +func TestResolver_MemFS_AliasImportResolution(t *testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf"}, &Context{ + Imports: []Import{ + {Alias: "pf", Package: "example.com/acme/perf"}, + }, + }) + + key, err := resolver.Resolve("pf.Order") + require.NoError(t, err) + require.Equal(t, "example.com/acme/perf.Order", key) +} + +func TestResolver_MemFS_AmbiguityDetection(t *testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf", "root/shared"}, &Context{ + Imports: []Import{ + {Alias: "pf", Package: "example.com/acme/perf"}, + {Alias: "sh", Package: "example.com/acme/shared"}, + }, + }) + + key, err := resolver.Resolve("Fee") + require.Empty(t, key) + require.Error(t, err) + amb, ok := err.(*AmbiguityError) + require.True(t, ok) + require.Equal(t, []string{ + "example.com/acme/perf.Fee", + "example.com/acme/shared.Fee", + }, amb.Candidates) +} + +func TestResolver_MemFS_ProvenanceCapture(t *testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf"}, &Context{ + DefaultPackage: "example.com/acme/perf", + }) + + resolved, err := resolver.ResolveWithProvenance("Order") + require.NoError(t, err) + require.NotNil(t, resolved) + require.Equal(t, "example.com/acme/perf.Order", resolved.ResolvedKey) + require.Equal(t, "default_package", resolved.MatchKind) + require.Equal(t, "ast_type", resolved.Provenance.Kind) + require.Equal(t, "example.com/acme/perf", resolved.Provenance.Package) + require.Equal(t, "root/perf/types.go", resolved.Provenance.File) +} + +func memFSResolver(t *testing.T, fsys fstest.MapFS, packageDirs []string, ctx *Context) *Resolver { + t.Helper() + registry := x.NewRegistry() + provenance := 
map[string]Provenance{} + for _, dir := range packageDirs { + pkg, err := xast.LoadPackageFS(context.Background(), fsys, dir) + require.NoError(t, err) + + fileByType := map[string]string{} + for _, file := range pkg.Files { + if file == nil { + continue + } + for _, item := range file.Types { + if item == nil || item.Name == "" { + continue + } + fileByType[item.Name] = path.Join(dir, file.Name) + } + } + for _, item := range pkg.Types { + if item == nil || item.Name == "" { + continue + } + aType := &x.Type{ + Name: item.Name, + PkgPath: pkg.PkgPath, + } + registry.Register(aType) + provenance[aType.Key()] = Provenance{ + Package: pkg.PkgPath, + File: fileByType[item.Name], + Kind: "ast_type", + } + } + } + return NewResolverWithProvenance(registry, ctx, provenance) +} + +func baseTypeMapFS() fstest.MapFS { + return fstest.MapFS{ + "root/go.mod": &fstest.MapFile{Data: []byte("module example.com/acme\n\ngo 1.23\n")}, + "root/perf/types.go": &fstest.MapFile{Data: []byte("package perf\n\ntype Order struct{}\ntype Fee struct{}\n")}, + "root/shared/types.go": &fstest.MapFile{Data: []byte("package shared\n\ntype Fee struct{}\n")}, + "root/ignore/other.txt": &fstest.MapFile{Data: []byte("skip")}, + } +} diff --git a/repository/shape/typectx/resolver_test.go b/repository/shape/typectx/resolver_test.go new file mode 100644 index 000000000..632a785da --- /dev/null +++ b/repository/shape/typectx/resolver_test.go @@ -0,0 +1,114 @@ +package typectx + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/x" +) + +type resolveFeeA struct{} +type resolveFeeB struct{} +type resolveOrder struct{} + +func TestResolver_Resolve_Unqualified_DefaultPackage(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{DefaultPackage: "github.com/acme/mdp/performance"}) + + key, err := 
resolver.Resolve("Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/mdp/performance.Order", key) +} + +func TestResolver_Resolve_AliasQualified(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{ + Imports: []Import{ + {Alias: "perf", Package: "github.com/acme/mdp/performance"}, + }, + }) + + key, err := resolver.Resolve("perf.Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/mdp/performance.Order", key) +} + +func TestResolver_Resolve_Unqualified_Ambiguous(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveFeeA{}), x.WithPkgPath("github.com/acme/alpha"), x.WithName("Fee"))) + reg.Register(x.NewType(reflect.TypeOf(resolveFeeB{}), x.WithPkgPath("github.com/acme/beta"), x.WithName("Fee"))) + resolver := NewResolver(reg, &Context{ + Imports: []Import{ + {Alias: "a", Package: "github.com/acme/alpha"}, + {Alias: "b", Package: "github.com/acme/beta"}, + }, + }) + + key, err := resolver.Resolve("Fee") + require.Empty(t, key) + require.Error(t, err) + amb, ok := err.(*AmbiguityError) + require.True(t, ok) + require.Equal(t, []string{ + "github.com/acme/alpha.Fee", + "github.com/acme/beta.Fee", + }, amb.Candidates) +} + +func TestResolver_Resolve_Unqualified_GlobalUniqueFallback(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/shared"), x.WithName("Order"))) + resolver := NewResolver(reg, nil) + + key, err := resolver.Resolve("Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/shared.Order", key) +} + +func TestResolver_ResolveWithProvenance(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := 
NewResolverWithProvenance(reg, &Context{DefaultPackage: "github.com/acme/mdp/performance"}, map[string]Provenance{ + "github.com/acme/mdp/performance.Order": { + Package: "github.com/acme/mdp/performance", + File: "/repo/mdp/performance/order.go", + Kind: "resource_type", + }, + }) + + resolved, err := resolver.ResolveWithProvenance("Order") + require.NoError(t, err) + require.NotNil(t, resolved) + require.Equal(t, "github.com/acme/mdp/performance.Order", resolved.ResolvedKey) + require.Equal(t, "default_package", resolved.MatchKind) + require.Equal(t, "/repo/mdp/performance/order.go", resolved.Provenance.File) + require.Equal(t, "resource_type", resolved.Provenance.Kind) +} + +func TestResolver_Resolve_Unqualified_PackagePath(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{PackagePath: "github.com/acme/mdp/performance"}) + + resolved, err := resolver.ResolveWithProvenance("Order") + require.NoError(t, err) + require.NotNil(t, resolved) + require.Equal(t, "github.com/acme/mdp/performance.Order", resolved.ResolvedKey) + require.Equal(t, "package_path", resolved.MatchKind) +} + +func TestResolver_Resolve_Qualified_PackageNameFallback(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{ + PackageName: "performance", + PackagePath: "github.com/acme/mdp/performance", + }) + + key, err := resolver.Resolve("performance.Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/mdp/performance.Order", key) +} diff --git a/repository/shape/typectx/source/resolver.go b/repository/shape/typectx/source/resolver.go new file mode 100644 index 000000000..639787d97 --- /dev/null +++ b/repository/shape/typectx/source/resolver.go @@ -0,0 +1,283 @@ +package 
source + +import ( + "fmt" + "go/ast" + "go/build" + "go/parser" + "go/token" + "golang.org/x/mod/modfile" + "os" + "path/filepath" + "sort" + "strings" +) + +type Config struct { + ProjectDir string + AllowedSourceRoots []string + UseGoModuleResolve bool + UseGOPATHFallback bool +} + +type Resolver struct { + projectDir string + modulePath string + replacements map[string]string + roots []string + useModule bool + useGOPATH bool +} + +func New(cfg Config) (*Resolver, error) { + projectDir := strings.TrimSpace(cfg.ProjectDir) + if projectDir == "" { + return nil, fmt.Errorf("typectx source: project dir was empty") + } + projectDir, err := filepath.Abs(projectDir) + if err != nil { + return nil, err + } + modulePath, replacements := loadModuleConfig(projectDir) + roots := NormalizeRoots(projectDir, cfg.AllowedSourceRoots) + return &Resolver{ + projectDir: projectDir, + modulePath: modulePath, + replacements: replacements, + roots: roots, + useModule: cfg.UseGoModuleResolve, + useGOPATH: cfg.UseGOPATHFallback, + }, nil +} + +func (r *Resolver) ResolvePackageDir(importPath string) (string, error) { + importPath = strings.TrimSpace(importPath) + if importPath == "" { + return "", fmt.Errorf("typectx source: empty import path") + } + if r.useModule { + if resolved := r.resolveReplace(importPath); resolved != "" { + return filepath.Clean(resolved), nil + } + if resolved := r.resolveProjectModule(importPath); resolved != "" { + return filepath.Clean(resolved), nil + } + if resolved := r.resolveModuleCache(importPath); resolved != "" { + return filepath.Clean(resolved), nil + } + } + if r.useGOPATH { + if resolved := resolveGOPATH(importPath); resolved != "" { + return filepath.Clean(resolved), nil + } + } + return "", fmt.Errorf("typectx source: package %s not resolved", importPath) +} + +func (r *Resolver) ResolveTypeFile(importPath, typeName string) (string, error) { + dir, err := r.ResolvePackageDir(importPath) + if err != nil { + return "", err + } + ok, err := 
IsWithinAnyRoot(dir, r.roots) + if err != nil { + return "", err + } + if !ok { + return "", fmt.Errorf("typectx source: package dir %s outside trusted roots", dir) + } + entries, err := os.ReadDir(dir) + if err != nil { + return "", err + } + fset := token.NewFileSet() + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { + continue + } + filePath := filepath.Join(dir, name) + parsed, parseErr := parser.ParseFile(fset, filePath, nil, parser.PackageClauseOnly|parser.ParseComments) + if parseErr != nil || parsed == nil { + continue + } + // Reparse full declaration only when package clause parsing succeeds. + parsed, parseErr = parser.ParseFile(fset, filePath, nil, 0) + if parseErr != nil || parsed == nil { + continue + } + for _, decl := range parsed.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + for _, spec := range gen.Specs { + ts, ok := spec.(*ast.TypeSpec) + if ok && ts.Name != nil && ts.Name.Name == typeName { + return filePath, nil + } + } + } + } + return "", fmt.Errorf("typectx source: type %s not found in %s", typeName, importPath) +} + +func (r *Resolver) Roots() []string { + return append([]string(nil), r.roots...) 
+} + +func (r *Resolver) resolveReplace(importPath string) string { + oldPaths := make([]string, 0, len(r.replacements)) + for old := range r.replacements { + oldPaths = append(oldPaths, old) + } + sort.SliceStable(oldPaths, func(i, j int) bool { return len(oldPaths[i]) > len(oldPaths[j]) }) + for _, old := range oldPaths { + if importPath != old && !strings.HasPrefix(importPath, old+"/") { + continue + } + mapped := r.replacements[old] + suffix := strings.TrimPrefix(importPath, old) + suffix = strings.TrimPrefix(suffix, "/") + if suffix == "" { + return mapped + } + return filepath.Join(mapped, filepath.FromSlash(suffix)) + } + return "" +} + +func (r *Resolver) resolveProjectModule(importPath string) string { + if r.modulePath == "" { + return "" + } + if importPath != r.modulePath && !strings.HasPrefix(importPath, r.modulePath+"/") { + return "" + } + suffix := strings.TrimPrefix(importPath, r.modulePath) + suffix = strings.TrimPrefix(suffix, "/") + if suffix == "" { + return r.projectDir + } + return filepath.Join(r.projectDir, filepath.FromSlash(suffix)) +} + +func (r *Resolver) resolveModuleCache(importPath string) string { + modCache := strings.TrimSpace(os.Getenv("GOMODCACHE")) + if modCache == "" { + if out, err := os.UserCacheDir(); err == nil && out != "" { + modCache = filepath.Join(filepath.Dir(out), "pkg", "mod") + } + } + if modCache == "" { + return "" + } + pattern := filepath.Join(modCache, filepath.FromSlash(importPath)+"@*") + matches, _ := filepath.Glob(pattern) + if len(matches) == 0 { + return "" + } + sort.Strings(matches) + return matches[len(matches)-1] +} + +func resolveGOPATH(importPath string) string { + gopath := strings.TrimSpace(os.Getenv("GOPATH")) + if gopath == "" { + gopath = strings.TrimSpace(build.Default.GOPATH) + } + if gopath == "" { + return "" + } + for _, root := range filepath.SplitList(gopath) { + candidate := filepath.Join(root, "src", filepath.FromSlash(importPath)) + if info, err := os.Stat(candidate); err == nil && 
// NormalizeRoots returns the sorted, de-duplicated set of trusted source
// roots: the project directory plus every allowed entry. Relative entries
// are anchored at projectDir; blanks are dropped.
func NormalizeRoots(projectDir string, allowed []string) []string {
	seen := map[string]bool{}
	var result []string
	appendRoot := func(value string) {
		value = strings.TrimSpace(value)
		if value == "" {
			return
		}
		if !filepath.IsAbs(value) {
			value = filepath.Join(projectDir, value)
		}
		value = filepath.Clean(value)
		if seen[value] {
			return
		}
		seen[value] = true
		result = append(result, value)
	}
	appendRoot(projectDir)
	for _, item := range allowed {
		appendRoot(item)
	}
	sort.Strings(result)
	return result
}

// IsWithinAnyRoot reports whether candidate resolves to a path inside (or
// equal to) at least one of the given roots.
func IsWithinAnyRoot(candidate string, roots []string) (bool, error) {
	candidate, err := filepath.Abs(candidate)
	if err != nil {
		return false, err
	}
	candidate = filepath.Clean(candidate)
	for _, root := range roots {
		rel, relErr := filepath.Rel(filepath.Clean(root), candidate)
		if relErr != nil {
			// An unrelatable root (e.g. a different volume on Windows)
			// cannot contain the candidate; keep checking the remaining
			// roots instead of aborting the whole containment test.
			continue
		}
		rel = filepath.ToSlash(rel)
		if rel == "." {
			return true, nil
		}
		// rel == ".." (candidate is the PARENT of the root) previously
		// slipped through the "../" prefix check and was treated as inside —
		// an escape from the trust boundary. Reject it explicitly.
		if rel != ".." && !strings.HasPrefix(rel, "../") {
			return true, nil
		}
	}
	return false, nil
}
os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte(`module example.com/project +go 1.25 +replace github.com/acme/models => ../shared-models +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(modelsDir, "mdp", "types.go"), []byte("package mdp\ntype Order struct{}\n"), 0o644)) + + denyResolver, err := New(Config{ + ProjectDir: projectDir, + UseGoModuleResolve: true, + UseGOPATHFallback: false, + }) + require.NoError(t, err) + _, err = denyResolver.ResolveTypeFile("github.com/acme/models/mdp", "Order") + require.Error(t, err) + + allowResolver, err := New(Config{ + ProjectDir: projectDir, + AllowedSourceRoots: []string{modelsDir}, + UseGoModuleResolve: true, + UseGOPATHFallback: false, + }) + require.NoError(t, err) + file, err := allowResolver.ResolveTypeFile("github.com/acme/models/mdp", "Order") + require.NoError(t, err) + require.Equal(t, filepath.Join(modelsDir, "mdp", "types.go"), file) +} + +func TestResolver_ResolvePackageDir_GOPATHFallback(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + gopath := filepath.Join(root, "gopath") + require.NoError(t, os.MkdirAll(filepath.Join(projectDir, "internal"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/project\ngo 1.25\n"), 0o644)) + legacyDir := filepath.Join(gopath, "src", "github.com", "legacy", "models") + require.NoError(t, os.MkdirAll(legacyDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(legacyDir, "types.go"), []byte("package models\ntype Legacy struct{}\n"), 0o644)) + + orig := os.Getenv("GOPATH") + require.NoError(t, os.Setenv("GOPATH", gopath)) + defer func() { _ = os.Setenv("GOPATH", orig) }() + + resolver, err := New(Config{ + ProjectDir: projectDir, + UseGoModuleResolve: false, + UseGOPATHFallback: true, + }) + require.NoError(t, err) + dir, err := resolver.ResolvePackageDir("github.com/legacy/models") + require.NoError(t, err) + require.Equal(t, legacyDir, dir) +} diff --git 
a/repository/shape/validate/relation.go b/repository/shape/validate/relation.go new file mode 100644 index 000000000..31aee9357 --- /dev/null +++ b/repository/shape/validate/relation.go @@ -0,0 +1,140 @@ +package validate + +import ( + "fmt" + "strings" + + "github.com/viant/datly/view" +) + +// ValidateRelations validates that relation link columns can be resolved on both +// parent and referenced views. It accepts alias/source/field variants and +// namespace-qualified forms (e.g. t.ID -> ID). +func ValidateRelations(resource *view.Resource, targets ...*view.View) error { + if resource == nil { + return nil + } + views := targets + if len(views) == 0 { + views = resource.Views + } + index := resource.Views.Index() + var issues []string + for _, parent := range views { + if parent == nil { + continue + } + parentIndex := view.Columns(parent.Columns).Index(parent.CaseFormat) + for _, rel := range parent.With { + if rel == nil || rel.Of == nil { + continue + } + ref := &rel.Of.View + if ref.Ref != "" { + if lookup, err := index.Lookup(ref.Ref); err == nil && lookup != nil { + ref = lookup + } + } + refIndex := view.Columns(ref.Columns).Index(ref.CaseFormat) + pairCount := len(rel.On) + if len(rel.Of.On) > pairCount { + pairCount = len(rel.Of.On) + } + for i := 0; i < pairCount; i++ { + var parentLink, refLink *view.Link + if i < len(rel.On) { + parentLink = rel.On[i] + } + if i < len(rel.Of.On) { + refLink = rel.Of.On[i] + } + + if missing := missingColumn(parentIndex, parentLink); missing != "" { + issues = append(issues, fmt.Sprintf("relation %q (parent=%q holder=%q link=%d): missing parent column %q", relName(rel, i), parent.Name, rel.Holder, i, missing)) + } + if missing := missingColumn(refIndex, refLink); missing != "" { + issues = append(issues, fmt.Sprintf("relation %q (parent=%q ref=%q holder=%q link=%d): missing ref column %q", relName(rel, i), parent.Name, ref.Name, rel.Holder, i, missing)) + } + } + } + } + if len(issues) == 0 { + return nil + } + return 
// trimIdentifier strips surrounding whitespace and any backtick, double- or
// single-quote characters wrapping an identifier.
func trimIdentifier(value string) string {
	value = strings.TrimSpace(value)
	for _, quote := range []string{"`", "\"", "'"} {
		value = strings.Trim(value, quote)
	}
	return value
}

// dedupe removes duplicates (compared case-insensitively, ignoring
// surrounding whitespace) and blank entries, preserving first-seen order.
func dedupe(values []string) []string {
	result := make([]string, 0, len(values))
	seen := map[string]bool{}
	for _, value := range values {
		key := strings.ToLower(strings.TrimSpace(value))
		if key == "" || seen[key] {
			continue
		}
		seen[key] = true
		result = append(result, value)
	}
	return result
}
( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestValidateRelations_AllowsAliasSourceAndNamespace(t *testing.T) { + parent := &view.View{ + Name: "vendor", + Columns: view.Columns{ + view.NewColumn("ID", "int", nil, false), + }, + } + child := &view.View{ + Name: "products", + Columns: view.Columns{ + view.NewColumn("VendorID", "int", nil, false, view.WithColumnTag(`source:"VENDOR_ID"`)), + }, + } + parent.With = []*view.Relation{{ + Name: "products", + Cardinality: state.Many, + Holder: "Products", + On: view.Links{&view.Link{Column: "vendor.ID"}}, + Of: &view.ReferenceView{ + View: *child, + On: view.Links{&view.Link{Column: "VENDOR_ID"}}, + }, + }} + resource := view.EmptyResource() + resource.Views = append(resource.Views, parent, child) + require.NoError(t, ValidateRelations(resource, parent)) +} + +func TestValidateRelations_DetailedMissingError(t *testing.T) { + parent := &view.View{ + Name: "vendor", + Columns: view.Columns{ + view.NewColumn("ID", "int", nil, false), + }, + } + child := &view.View{ + Name: "products", + Columns: view.Columns{ + view.NewColumn("VendorID", "int", nil, false), + }, + } + parent.With = []*view.Relation{{ + Name: "products", + Cardinality: state.Many, + Holder: "Products", + On: view.Links{&view.Link{Column: "MISSING_PARENT"}}, + Of: &view.ReferenceView{ + View: *child, + On: view.Links{&view.Link{Column: "MISSING_CHILD"}}, + }, + }} + resource := view.EmptyResource() + resource.Views = append(resource.Views, parent, child) + err := ValidateRelations(resource, parent) + require.Error(t, err) + require.Contains(t, err.Error(), "missing parent column \"MISSING_PARENT\"") + require.Contains(t, err.Error(), "missing ref column \"MISSING_CHILD\"") + require.Contains(t, err.Error(), "parent=\"vendor\"") + require.Contains(t, err.Error(), "ref=\"products\"") +} diff --git a/repository/shape/xgen/generator.go b/repository/shape/xgen/generator.go 
new file mode 100644 index 000000000..d0fc419a8 --- /dev/null +++ b/repository/shape/xgen/generator.go @@ -0,0 +1,684 @@ +package xgen + +import ( + "fmt" + "go/ast" + "os" + "path/filepath" + "reflect" + "sort" + "strings" + + "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/repository/shape/typectx/source" + "github.com/viant/x" + xreflectloader "github.com/viant/x/loader/xreflect" + "github.com/viant/x/syntetic" + "github.com/viant/x/syntetic/model" +) + +// GenerateFromDQLShape emits Go structs from DQL shape using viant/x registry. +func GenerateFromDQLShape(doc *shape.Document, cfg *Config) (*Result, error) { + if doc == nil || doc.Root == nil { + return nil, fmt.Errorf("shape xgen: nil document") + } + if cfg == nil { + cfg = &Config{} + } + hydrateConfigFromTypeContext(doc, cfg) + applyDefaults(cfg) + projectDir, packageDir, err := resolvePaths(cfg.ProjectDir, cfg.PackageDir) + if err != nil { + return nil, err + } + packageName := resolvePackageName(cfg.PackageName, packageDir) + packagePath, err := resolvePackagePath(cfg.PackagePath, projectDir, packageDir) + if err != nil { + return nil, err + } + fileName := cfg.FileName + if strings.TrimSpace(fileName) == "" { + fileName = "shapes_gen.go" + } + registry := cfg.Registry + if registry == nil { + registry = x.NewRegistry() + } + views := extractViews(doc.Root) + routeTypes := extractRouteIO(doc.Root) + if len(views) == 0 && len(routeTypes) == 0 { + return nil, fmt.Errorf("shape xgen: no view or route io declarations") + } + typeNames := make([]string, 0, len(views)+len(routeTypes)) + registered := map[string]bool{} + for _, view := range views { + typeName := viewTypeName(cfg, view) + if registered[typeName] { + continue + } + registered[typeName] = true + if err = registerShapeType(registry, packagePath, typeName, buildStructType(view.columns)); err != nil { + return nil, err + } + typeNames = append(typeNames, typeName) + } 
+ for _, ioType := range routeTypes { + typeName := routeTypeName(cfg, ioType) + if typeName == "" || registered[typeName] { + continue + } + registered[typeName] = true + if err = registerShapeType(registry, packagePath, typeName, buildStructType(ioType.fields)); err != nil { + return nil, err + } + typeNames = append(typeNames, typeName) + } + namespace, err := syntetic.FromRegistry(registry) + if err != nil { + return nil, err + } + namespace.PkgName = packageName + namespace.PkgPath = packagePath + files, err := namespace.BuildFiles(model.RenderOptions{}) + if err != nil { + return nil, err + } + goFile := files[packagePath] + if goFile == nil { + return nil, fmt.Errorf("shape xgen: missing generated package file for %s", packagePath) + } + source, err := goFile.Render() + if err != nil { + return nil, err + } + if err = os.MkdirAll(packageDir, 0o755); err != nil { + return nil, err + } + dest := filepath.Join(packageDir, fileName) + if exists, checkErr := fileExists(dest); checkErr != nil { + return nil, checkErr + } else if exists && !cfg.AllowUnsafeRewrite { + if issues := rewriteSafetyIssues(doc, cfg, projectDir); len(issues) > 0 && (cfg.StrictProvenance == nil || *cfg.StrictProvenance) { + return nil, fmt.Errorf("shape xgen: rewrite blocked by type provenance safety: %s", strings.Join(issues, "; ")) + } + merged, mergeErr := mergeGeneratedShapes(dest, []byte(source), typeNames) + if mergeErr != nil { + return nil, mergeErr + } + source = string(merged) + } + if err = writeAtomic(dest, []byte(source), 0o644); err != nil { + return nil, err + } + sort.Strings(typeNames) + return &Result{ + FilePath: dest, + PackagePath: packagePath, + PackageName: packageName, + Types: typeNames, + }, nil +} + +func rewriteSafetyIssues(doc *shape.Document, cfg *Config, projectDir string) []string { + if doc == nil || len(doc.TypeResolutions) == 0 { + return nil + } + policy := newRewritePolicy(cfg, projectDir) + srcResolver, _ := source.New(source.Config{ + ProjectDir: 
projectDir, + AllowedSourceRoots: policy.roots, + UseGoModuleResolve: policy.useModule, + UseGOPATHFallback: policy.useGOPATH, + }) + var issues []string + for i := range doc.TypeResolutions { + resolution := &doc.TypeResolutions[i] + if srcResolver != nil && strings.TrimSpace(resolution.Provenance.File) == "" { + pkg := inferResolutionPackage(*resolution, doc.TypeContext) + name := typeNameFromKey(resolution.ResolvedKey) + if name == "" { + name = strings.TrimSpace(resolution.Expression) + } + if pkg != "" && name != "" { + if file, err := srcResolver.ResolveTypeFile(pkg, name); err == nil { + resolution.Provenance.File = file + if resolution.Provenance.Kind == "" || strings.EqualFold(resolution.Provenance.Kind, "registry") { + resolution.Provenance.Kind = "ast_type" + } + } + } + } + if issue := resolutionSafetyIssue(*resolution, policy); issue != "" { + issues = append(issues, issue) + } + } + sort.Strings(issues) + return uniqueStrings(issues) +} + +func hydrateConfigFromTypeContext(doc *shape.Document, cfg *Config) { + if doc == nil || cfg == nil || doc.TypeContext == nil { + return + } + if cfg.PackageDir == "" { + cfg.PackageDir = strings.TrimSpace(doc.TypeContext.PackageDir) + } + if cfg.PackageName == "" { + cfg.PackageName = strings.TrimSpace(doc.TypeContext.PackageName) + } + if cfg.PackagePath == "" { + cfg.PackagePath = strings.TrimSpace(doc.TypeContext.PackagePath) + } +} + +func inferResolutionPackage(resolution typectx.Resolution, ctx *typectx.Context) string { + pkg := strings.TrimSpace(resolution.Provenance.Package) + if pkg != "" { + return pkg + } + pkg = packageOfKey(resolution.ResolvedKey) + if pkg != "" { + return pkg + } + if ctx != nil { + if pkg = strings.TrimSpace(ctx.PackagePath); pkg != "" { + return pkg + } + if pkg = strings.TrimSpace(ctx.DefaultPackage); pkg != "" { + return pkg + } + } + return "" +} + +func resolutionSafetyIssue(resolution typectx.Resolution, policy rewritePolicy) string { + kind := 
// typeNameFromKey extracts the type identifier after the final dot of a
// "pkg/path.Type" key; empty when there is no dot or nothing follows it.
func typeNameFromKey(key string) string {
	if i := strings.LastIndex(key, "."); i >= 0 && i < len(key)-1 {
		return key[i+1:]
	}
	return ""
}

// packageOfKey extracts the package path preceding the final dot of a
// "pkg/path.Type" key; empty when the key has no dot.
func packageOfKey(key string) string {
	i := strings.LastIndex(key, ".")
	if i < 0 {
		return ""
	}
	return key[:i]
}
len(items) < 2 { + return items + } + result := items[:0] + var previous string + for i, item := range items { + if i == 0 || item != previous { + result = append(result, item) + } + previous = item + } + return result +} + +func registerShapeType(registry *x.Registry, packagePath string, typeName string, rType reflect.Type) error { + st, err := xreflectloader.BuildType(rType, + xreflectloader.WithPackagePath(packagePath), + xreflectloader.WithNamePolicy(func(reflect.Type) (string, bool) { + return typeName, false + })) + if err != nil { + return fmt.Errorf("shape xgen: build type %s failed: %w", typeName, err) + } + st.Name = typeName + st.PkgPath = packagePath + if st.TypeSpec != nil { + st.TypeSpec.Name = ast.NewIdent(typeName) + } + registry.Register(x.NewType(rType, + x.WithName(typeName), + x.WithPkgPath(packagePath), + x.WithSyntheticType(st))) + return nil +} + +type viewDescriptor struct { + name any + schemaName any + columns []columnDescriptor +} + +type ioTypeKind string + +const ( + ioTypeInput ioTypeKind = "input" + ioTypeOutput ioTypeKind = "output" +) + +type routeIODescriptor struct { + kind ioTypeKind + routeName string + routeURI string + routeRef string + typeName string + fields []columnDescriptor +} + +type columnDescriptor struct { + name string + dataType string +} + +func extractViews(root map[string]any) []viewDescriptor { + resource := asMap(root["Resource"]) + if resource == nil { + return nil + } + items := asSlice(resource["Views"]) + result := make([]viewDescriptor, 0, len(items)) + for _, item := range items { + view := asMap(item) + if view == nil { + continue + } + schema := asMap(view["Schema"]) + descriptor := viewDescriptor{ + name: view["Name"], + schemaName: nil, + } + if schema != nil { + descriptor.schemaName = schema["Name"] + } + descriptor.columns = extractColumns(view) + result = append(result, descriptor) + } + return result +} + +func extractColumns(view map[string]any) []columnDescriptor { + var result 
[]columnDescriptor
	// Explicit Columns entries win; each needs at least a resolvable name.
	if columns := asSlice(view["Columns"]); len(columns) > 0 {
		for _, item := range columns {
			column := asMap(item)
			if column == nil {
				continue
			}
			name := firstNonEmpty(asString(column["Name"]), asString(column["Column"]))
			if name == "" {
				continue
			}
			result = append(result, columnDescriptor{name: name, dataType: asString(column["DataType"])})
		}
	}
	// ColumnsConfig is a map; iterate keys in sorted order so output is deterministic.
	if cfg := asMap(view["ColumnsConfig"]); len(cfg) > 0 {
		keys := make([]string, 0, len(cfg))
		for k := range cfg {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, key := range keys {
			item := asMap(cfg[key])
			if item == nil {
				item = map[string]any{}
			}
			name := firstNonEmpty(asString(item["Name"]), key)
			result = append(result, columnDescriptor{name: name, dataType: asString(item["DataType"])})
		}
	}
	// Fall back to a minimal identifier column so struct building never fails.
	if len(result) == 0 {
		result = append(result, columnDescriptor{name: "ID", dataType: "int"})
	}
	return result
}

// extractRouteIO collects one descriptor per declared route Input and Output,
// carrying the route name/URI/view ref shared by both directions.
func extractRouteIO(root map[string]any) []routeIODescriptor {
	var result []routeIODescriptor
	for _, item := range asSlice(root["Routes"]) {
		route := asMap(item)
		if route == nil {
			continue
		}
		meta := routeIODescriptor{
			routeName: asString(route["Name"]),
			routeURI:  asString(route["URI"]),
		}
		if routeView := asMap(route["View"]); routeView != nil {
			meta.routeRef = asString(routeView["Ref"])
		}
		if input := asMap(route["Input"]); input != nil {
			entry := meta
			entry.kind = ioTypeInput
			entry.typeName = nestedTypeName(input)
			entry.fields = extractIOFields(input)
			result = append(result, entry)
		}
		if output := asMap(route["Output"]); output != nil {
			entry := meta
			entry.kind = ioTypeOutput
			entry.typeName = nestedTypeName(output)
			entry.fields = extractIOFields(output)
			result = append(result, entry)
		}
	}
	return result
}

// nestedTypeName returns the explicit Type.Name declared on an Input/Output node, if any.
func nestedTypeName(io map[string]any) string {
	aType := asMap(io["Type"])
	if aType == nil {
		return ""
	}
	return asString(aType["Name"])
}

// extractIOFields maps Input/Output parameters to column descriptors; parameters may
// sit directly on the node or nested under Type.Parameters.
func extractIOFields(io map[string]any) []columnDescriptor {
	parameters := asSlice(io["Parameters"])
	if len(parameters) == 0 {
		if t := asMap(io["Type"]); t != nil {
			parameters = asSlice(t["Parameters"])
		}
	}
	fields := make([]columnDescriptor, 0, len(parameters))
	for _, item := range parameters {
		param := asMap(item)
		if param == nil {
			continue
		}
		name := asString(param["Name"])
		if name == "" {
			continue
		}
		dataType := ""
		if schema := asMap(param["Schema"]); schema != nil {
			dataType = asString(schema["DataType"])
		}
		fields = append(fields, columnDescriptor{name: name, dataType: dataType})
	}
	if len(fields) == 0 {
		fields = append(fields, columnDescriptor{name: "ID", dataType: "int"})
	}
	return fields
}

// buildStructType builds a reflect struct with one field per column; names are
// exported and de-duplicated, tags carry json + sqlx metadata.
func buildStructType(columns []columnDescriptor) reflect.Type {
	if len(columns) == 0 {
		columns = []columnDescriptor{{name: "ID", dataType: "int"}}
	}
	fields := make([]reflect.StructField, 0, len(columns))
	used := map[string]bool{}
	for _, column := range columns {
		base := exportedName(column.name)
		if base == "" {
			base = "Field"
		}
		// BUG FIX: the previous counter tracked the *renamed* field, so a third column
		// mapping to the same identifier reproduced the prior suffixed name and
		// reflect.StructOf panicked on the duplicate; probe the set until unique.
		fieldName := base
		for i := 2; used[fieldName]; i++ {
			fieldName = fmt.Sprintf("%s%d", base, i)
		}
		used[fieldName] = true
		fields = append(fields, reflect.StructField{
			Name: fieldName,
			Type: parseType(column.dataType),
			Tag:  reflect.StructTag(fmt.Sprintf(`json:"%s,omitempty" sqlx:"%s"`, strings.ToLower(fieldName), column.name)),
		})
	}
	return reflect.StructOf(fields)
}

// parseType maps a textual SQL/Go-ish data type to a reflect.Type; slice ("[]") and
// pointer ("*") prefixes recurse; unknown names default to string.
func parseType(dataType string) reflect.Type {
	dataType = strings.TrimSpace(dataType)
	if dataType == "" {
		return reflect.TypeOf("")
	}
	switch {
	case strings.HasPrefix(dataType, "[]"):
		return reflect.SliceOf(parseType(strings.TrimPrefix(dataType, "[]")))
	case strings.HasPrefix(dataType, "*"):
		return reflect.PointerTo(parseType(strings.TrimPrefix(dataType, "*")))
	}
	switch strings.ToLower(dataType) {
	case "string", "varchar", "text":
		return reflect.TypeOf("")
	case "bool", "boolean":
		return reflect.TypeOf(true)
	case "int", "integer":
		return reflect.TypeOf(int(0))
	case "int64", "bigint":
		return reflect.TypeOf(int64(0))
	case "int32":
		return reflect.TypeOf(int32(0))
	case "float", "float64", "double", "decimal":
		return reflect.TypeOf(float64(0))
	case "float32":
		return reflect.TypeOf(float32(0))
	case "byte", "uint8":
		// BUG FIX: "[]byte" is stripped by the "[]" prefix branch above and recursed
		// as "byte"; without this case it defaulted to string, producing []string.
		return reflect.TypeOf(byte(0))
	case "bytes", "blob":
		return reflect.TypeOf([]byte{})
	default:
		return reflect.TypeOf("")
	}
}

// exportedName converts an arbitrary column name into an exported Go identifier:
// split on non-alphanumerics, title-case each part (all-caps parts are normalized
// like "ID"->"Id"), and prefix "N" when the result starts with a digit.
func exportedName(name string) string {
	name = strings.TrimSpace(name)
	if name == "" {
		return ""
	}
	var parts []string
	current := strings.Builder{}
	flush := func() {
		if current.Len() == 0 {
			return
		}
		parts = append(parts, current.String())
		current.Reset()
	}
	for _, r := range name {
		if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
			current.WriteRune(r)
		} else {
			flush()
		}
	}
	flush()
	for i, item := range parts {
		if item == strings.ToUpper(item) {
			parts[i] = strings.ToUpper(item[:1]) + strings.ToLower(item[1:])
		} else {
			parts[i] = strings.ToUpper(item[:1]) + item[1:]
		}
	}
	result := strings.Join(parts, "")
	if result == "" {
		return ""
	}
	if result[0] >= '0' && result[0] <= '9' {
		result = "N" + result
	}
	return result
}

// applyDefaults fills zero-valued Config fields with their documented defaults.
func applyDefaults(cfg *Config) {
	if cfg.ViewSuffix == "" {
		cfg.ViewSuffix = "View"
	}
	if cfg.InputSuffix == "" {
		cfg.InputSuffix = "Input"
	}
	if cfg.OutputSuffix == "" {
		cfg.OutputSuffix = "Output"
	}
	if cfg.UseGoModuleResolve == nil {
		value := true
		cfg.UseGoModuleResolve = &value
	}
	if cfg.UseGOPATHFallback == nil {
		value := true
		cfg.UseGOPATHFallback = &value
	}
	if cfg.StrictProvenance == nil {
		value := true
		cfg.StrictProvenance = &value
	}
}

// viewTypeName resolves a view's generated type name: custom namer first, then
// schema name, then view name, always suffixed (unless already) and prefixed.
func viewTypeName(cfg *Config, view viewDescriptor) string {
	ctx := ViewTypeContext{
		ViewName:   asString(view.name),
		SchemaName: asString(view.schemaName),
	}
	if cfg.ViewTypeNamer != nil {
		if name := strings.TrimSpace(cfg.ViewTypeNamer(ctx)); name != "" {
			return cfg.TypePrefix + exportedName(name)
		}
	}
	base := firstNonEmpty(ctx.SchemaName, ctx.ViewName)
	if base == "" {
		base = cfg.ViewSuffix
	} else if !hasCaseInsensitiveSuffix(base, cfg.ViewSuffix) {
		base += cfg.ViewSuffix
	}
	return cfg.TypePrefix + exportedName(base)
}

// routeTypeName resolves a route Input/Output type name with the same precedence:
// direction-specific namer, explicit type name, route name, view ref, "Route".
func routeTypeName(cfg *Config, route routeIODescriptor) string {
	ctx := RouteTypeContext{
		RouteName: route.routeName,
		RouteURI:  route.routeURI,
		RouteRef:  route.routeRef,
		TypeName:  route.typeName,
	}
	var custom string
	switch route.kind {
	case ioTypeInput:
		if cfg.InputTypeNamer != nil {
			custom = cfg.InputTypeNamer(ctx)
		}
	case ioTypeOutput:
		if cfg.OutputTypeNamer != nil {
			custom = cfg.OutputTypeNamer(ctx)
		}
	}
	if strings.TrimSpace(custom) != "" {
		return cfg.TypePrefix + exportedName(custom)
	}
	base := firstNonEmpty(ctx.TypeName, ctx.RouteName, ctx.RouteRef, "Route")
	suffix := cfg.OutputSuffix
	if route.kind == ioTypeInput {
		suffix = cfg.InputSuffix
	}
	if !hasCaseInsensitiveSuffix(base, suffix) {
		base += suffix
	}
	return cfg.TypePrefix + exportedName(base)
}

// hasCaseInsensitiveSuffix reports whether value already ends with suffix, ignoring case.
func hasCaseInsensitiveSuffix(value, suffix string) bool {
	if suffix == "" {
		return true
	}
	return strings.HasSuffix(strings.ToLower(value), strings.ToLower(suffix))
}

// firstNonEmpty returns the first value that is not blank after trimming.
func firstNonEmpty(values ...string) string {
	for _, item := range values {
		if strings.TrimSpace(item) != "" {
			return item
		}
	}
	return ""
}

// asMap coerces loosely typed YAML/JSON nodes to map[string]any (including map[any]any).
func asMap(raw any) map[string]any {
	if value, ok := raw.(map[string]any); ok {
		return value
	}
	if value, ok := raw.(map[any]any); ok {
		out := map[string]any{}
		for key, item := range value {
			out[fmt.Sprint(key)] = item
		}
		return out
	}
	return nil
}

// asSlice returns raw as []any, or nil when it is anything else.
func asSlice(raw any) []any {
	if value, ok := raw.([]any); ok {
		return value
	}
	return nil
}

// asString stringifies raw: nil -> "", string passthrough, anything else via fmt.
func asString(raw any) string {
	if raw == nil {
		return ""
	}
	if value, ok := raw.(string); ok {
		return value
} + return fmt.Sprint(raw) +} diff --git a/repository/shape/xgen/generator_test.go b/repository/shape/xgen/generator_test.go new file mode 100644 index 000000000..eb8f35b13 --- /dev/null +++ b/repository/shape/xgen/generator_test.go @@ -0,0 +1,480 @@ +package xgen + +import ( + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" + "testing" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func TestGenerateFromDQLShape(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + doc := &dqlshape.Document{Root: map[string]any{ + "Routes": []any{ + map[string]any{ + "Name": "orders", + "URI": "/orders", + "View": map[string]any{"Ref": "orders"}, + "Input": map[string]any{ + "Type": map[string]any{"Name": "OrdersFilter"}, + "Parameters": []any{ + map[string]any{ + "Name": "status", + "Schema": map[string]any{ + "DataType": "string", + }, + }, + }, + }, + "Output": map[string]any{ + "Type": map[string]any{"Name": "OrdersPayload"}, + "Parameters": []any{ + map[string]any{ + "Name": "total", + "Schema": map[string]any{ + "DataType": "int", + }, + }, + }, + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", + "Schema": map[string]any{ + "Name": "OrderView", + }, + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + "NAME": map[string]any{"Name": "NAME", "DataType": "string"}, + }, + }, + }, + }, + }} + result, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + TypePrefix: "DQL", + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + if result == nil { + t.Fatalf("nil result") + } + if len(result.Types) == 0 { + t.Fatalf("expected 
generated types") + } + if _, err = os.Stat(result.FilePath); err != nil { + t.Fatalf("generated file missing: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file failed: %v", err) + } + source := string(data) + if !strings.Contains(source, "type DQLOrderView struct") { + t.Fatalf("expected generated type in source, got:\n%s", source) + } + if !strings.Contains(source, "type DQLOrdersFilterInput struct") || !strings.Contains(source, "type DQLOrdersPayloadOutput struct") { + t.Fatalf("expected io types in source, got:\n%s", source) + } + if !strings.Contains(source, "Id") || !strings.Contains(source, "Name") { + t.Fatalf("expected generated fields in source, got:\n%s", source) + } + fset := token.NewFileSet() + if _, err = parser.ParseFile(fset, result.FilePath, source, parser.AllErrors); err != nil { + t.Fatalf("generated file parse failed: %v", err) + } +} + +func TestGenerateFromDQLShape_CustomTypeNamers(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + doc := &dqlshape.Document{Root: map[string]any{ + "Routes": []any{ + map[string]any{ + "Name": "orders", + "Input": map[string]any{ + "Parameters": []any{map[string]any{"Name": "q", "Schema": map[string]any{"DataType": "string"}}}, + }, + "Output": map[string]any{ + "Parameters": []any{map[string]any{"Name": "count", "Schema": map[string]any{"DataType": "int"}}}, + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "orders", "ColumnsConfig": map[string]any{"ID": map[string]any{"Name": "ID", "DataType": "int"}}}, + }, + }, + }} + result, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + ViewTypeNamer: func(ctx ViewTypeContext) string { + return "DataOrders" + }, + InputTypeNamer: func(ctx 
RouteTypeContext) string { + return "ReqOrders" + }, + OutputTypeNamer: func(ctx RouteTypeContext) string { + return "ResOrders" + }, + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file failed: %v", err) + } + source := string(data) + if !strings.Contains(source, "type DataOrders struct") { + t.Fatalf("missing custom view type: %s", source) + } + if !strings.Contains(source, "type ReqOrders struct") { + t.Fatalf("missing custom input type: %s", source) + } + if !strings.Contains(source, "type ResOrders struct") { + t.Fatalf("missing custom output type: %s", source) + } +} + +func TestGenerateFromDQLShape_BlocksUnsafeRewriteByProvenance(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + packageDir := filepath.Join(projectDir, "internal", "gen") + if err := os.MkdirAll(packageDir, 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + dest := filepath.Join(packageDir, "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{ + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "orders", "ColumnsConfig": map[string]any{"ID": map[string]any{"Name": "ID", "DataType": "int"}}}, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Fee", + Provenance: typectx.Provenance{Kind: "registry"}, + }, + }, + } + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + }) + if err == nil || !strings.Contains(err.Error(), "rewrite blocked") { + t.Fatalf("expected rewrite blocked error, got: %v", err) + } +} + +func 
TestGenerateFromDQLShape_AllowsUnsafeRewriteWithOverride(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + packageDir := filepath.Join(projectDir, "internal", "gen") + if err := os.MkdirAll(packageDir, 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + dest := filepath.Join(packageDir, "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{ + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "orders", "ColumnsConfig": map[string]any{"ID": map[string]any{"Name": "ID", "DataType": "int"}}}, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Fee", + Provenance: typectx.Provenance{Kind: "registry"}, + }, + }, + } + result, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + AllowUnsafeRewrite: true, + }) + if err != nil { + t.Fatalf("expected override rewrite success, got: %v", err) + } + if result == nil || result.FilePath == "" { + t.Fatalf("expected generated result") + } +} + +func TestGenerateFromDQLShape_MergesIntoExistingFile(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + packageDir := filepath.Join(projectDir, "internal", "gen") + if err := os.MkdirAll(packageDir, 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + dest := filepath.Join(packageDir, "shapes_gen.go") + initial := `package gen + +type DQLOrderView struct { + Old string ` + "`json:\"old,omitempty\"`" + ` +} + +func 
TestGenerateFromDQLShape_UsesTypeContextPackageDefaults(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + doc := &dqlshape.Document{ + TypeContext: &typectx.Context{ + PackageDir: "pkg/platform/taxonomy", + PackageName: "taxonomy", + PackagePath: "example.com/demo/pkg/platform/taxonomy", + }, + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }, + } + result, err := GenerateFromDQLShape(doc, &Config{ProjectDir: projectDir}) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + if result == nil { + t.Fatalf("expected result") + } + if result.PackageName != "taxonomy" { + t.Fatalf("expected package name taxonomy, got %q", result.PackageName) + } + if result.PackagePath != "example.com/demo/pkg/platform/taxonomy" { + t.Fatalf("expected package path from type context, got %q", result.PackagePath) + } + if !strings.Contains(filepath.ToSlash(result.FilePath), "/pkg/platform/taxonomy/") { + t.Fatalf("expected file under type-context package dir, got %s", result.FilePath) + } +} + +func TestGenerateFromDQLShape_ProvenanceEnrichment_WithReplaceAndTypeContextPackagePath(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + modelsDir := filepath.Join(root, "shared-models") + if err := os.MkdirAll(filepath.Join(projectDir, "internal", "gen"), 0o755); err != nil { + t.Fatalf("mkdir project failed: %v", err) + } + if err := os.MkdirAll(filepath.Join(modelsDir, "mdp"), 0o755); err != nil { + t.Fatalf("mkdir models failed: %v", err) + } + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/project\n\ngo 1.25\nreplace github.com/acme/models => 
../shared-models\n"), 0o644); err != nil { + t.Fatalf("write project go.mod failed: %v", err) + } + if err := os.WriteFile(filepath.Join(modelsDir, "go.mod"), []byte("module github.com/acme/models\n\ngo 1.25\n"), 0o644); err != nil { + t.Fatalf("write models go.mod failed: %v", err) + } + if err := os.WriteFile(filepath.Join(modelsDir, "mdp", "types.go"), []byte("package mdp\ntype Order struct{}\n"), 0o644); err != nil { + t.Fatalf("write types.go failed: %v", err) + } + dest := filepath.Join(projectDir, "internal", "gen", "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{ + TypeContext: &typectx.Context{ + PackagePath: "github.com/acme/models/mdp", + }, + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Order", + ResolvedKey: "Order", + Provenance: typectx.Provenance{ + Kind: "registry", + }, + }, + }, + } + + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + AllowedSourceRoots: []string{modelsDir}, + }) + if err != nil { + t.Fatalf("expected provenance enrichment to allow rewrite, got: %v", err) + } +} + +func TestGenerateFromDQLShape_ProvenanceEnrichment_WithGOPATHFallback(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + gopath := filepath.Join(root, "gopath") + modelsDir := filepath.Join(gopath, "src", "github.com", "legacy", "models") + if err := os.MkdirAll(filepath.Join(projectDir, "internal", "gen"), 0o755); err != nil { + t.Fatalf("mkdir project failed: %v", err) + } + if err := os.MkdirAll(modelsDir, 0o755); err != nil { + t.Fatalf("mkdir models failed: %v", err) + } + 
if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/project\n\ngo 1.25\n"), 0o644); err != nil { + t.Fatalf("write project go.mod failed: %v", err) + } + if err := os.WriteFile(filepath.Join(modelsDir, "types.go"), []byte("package models\ntype Legacy struct{}\n"), 0o644); err != nil { + t.Fatalf("write types.go failed: %v", err) + } + dest := filepath.Join(projectDir, "internal", "gen", "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + orig := os.Getenv("GOPATH") + if err := os.Setenv("GOPATH", gopath); err != nil { + t.Fatalf("set GOPATH failed: %v", err) + } + defer func() { _ = os.Setenv("GOPATH", orig) }() + + doc := &dqlshape.Document{ + TypeContext: &typectx.Context{ + PackagePath: "github.com/legacy/models", + }, + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "legacy", + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Legacy", + ResolvedKey: "Legacy", + Provenance: typectx.Provenance{Kind: "registry"}, + }, + }, + } + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + AllowedSourceRoots: []string{filepath.Join(gopath, "src")}, + UseGoModuleResolve: boolPtr(false), + UseGOPATHFallback: boolPtr(true), + }) + if err != nil { + t.Fatalf("expected GOPATH provenance enrichment to allow rewrite, got: %v", err) + } +} + +func boolPtr(value bool) *bool { + return &value +} + +func KeepCustom() string { return "ok" } +` + if err := os.WriteFile(dest, []byte(initial), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", 
+ "Schema": map[string]any{ + "Name": "OrderView", + }, + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }} + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + TypePrefix: "DQL", + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + + data, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("read generated file failed: %v", err) + } + source := string(data) + if !strings.Contains(source, "func KeepCustom() string") { + t.Fatalf("expected custom function preserved, got:\n%s", source) + } + if strings.Contains(source, "Old string") { + t.Fatalf("expected old shape declaration replaced, got:\n%s", source) + } + if !strings.Contains(source, "type DQLOrderView struct") || !strings.Contains(source, "Id int") { + t.Fatalf("expected updated shape declaration, got:\n%s", source) + } +} diff --git a/repository/shape/xgen/io.go b/repository/shape/xgen/io.go new file mode 100644 index 000000000..50e86ddaa --- /dev/null +++ b/repository/shape/xgen/io.go @@ -0,0 +1,297 @@ +package xgen + +import ( + "bufio" + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "path/filepath" + "sort" + "strings" +) + +func resolvePaths(projectDir, packageDir string) (string, string, error) { + if strings.TrimSpace(projectDir) == "" { + return "", "", fmt.Errorf("shape xgen: project dir was empty") + } + projectDir = filepath.Clean(projectDir) + if strings.TrimSpace(packageDir) == "" { + packageDir = projectDir + } else if !filepath.IsAbs(packageDir) { + packageDir = filepath.Join(projectDir, packageDir) + } + packageDir = filepath.Clean(packageDir) + return projectDir, packageDir, nil +} + +func resolvePackageName(name string, packageDir string) string { + name = strings.TrimSpace(name) + if name != "" { + return name + } + base := filepath.Base(packageDir) + if base == "." 
|| base == string(filepath.Separator) || base == "" {
		return "generated"
	}
	return sanitizePkg(base)
}

// resolvePackagePath returns the supplied import path when set; otherwise it
// derives one from the project's go.mod module path plus the package dir
// relative to the project root.
func resolvePackagePath(packagePath, projectDir, packageDir string) (string, error) {
	packagePath = strings.TrimSpace(packagePath)
	if packagePath != "" {
		return packagePath, nil
	}
	modulePath, err := readModulePath(filepath.Join(projectDir, "go.mod"))
	if err != nil {
		return "", err
	}
	rel, err := filepath.Rel(projectDir, packageDir)
	if err != nil {
		return "", err
	}
	rel = filepath.ToSlash(rel)
	if rel == "." {
		// package dir is the project root: the module path itself is the import path
		return modulePath, nil
	}
	return strings.TrimRight(modulePath, "/") + "/" + strings.TrimLeft(rel, "/"), nil
}

// readModulePath scans go.mod line by line and returns the first non-empty
// "module" directive value.
// NOTE(review): requires the directive to start with "module " exactly (a tab
// or quoted path would be missed/kept verbatim) — confirm inputs.
func readModulePath(goModPath string) (string, error) {
	file, err := os.Open(goModPath)
	if err != nil {
		return "", fmt.Errorf("shape xgen: open go.mod failed: %w", err)
	}
	defer file.Close()
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "module ") {
			continue
		}
		modulePath := strings.TrimSpace(strings.TrimPrefix(line, "module "))
		if modulePath != "" {
			return modulePath, nil
		}
	}
	if err = scanner.Err(); err != nil {
		return "", err
	}
	return "", fmt.Errorf("shape xgen: module path not found in %s", goModPath)
}

// sanitizePkg lowercases the candidate and strips every rune outside [a-z0-9_],
// falling back to "generated" and prefixing "p" when the result starts with a digit.
func sanitizePkg(name string) string {
	name = strings.TrimSpace(strings.ToLower(name))
	if name == "" {
		return "generated"
	}
	var out strings.Builder
	for _, r := range name {
		if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' {
			out.WriteRune(r)
		}
	}
	if out.Len() == 0 {
		return "generated"
	}
	result := out.String()
	if result[0] >= '0' && result[0] <= '9' {
		return "p" + result
	}
	return result
}

// writeAtomic writes data to a temp file in the destination directory and
// renames it into place, so readers never observe a partially written file.
func writeAtomic(path string, data []byte, perm os.FileMode) error {
	dir := filepath.Dir(path)
	temp, err := os.CreateTemp(dir, ".tmp-shape-xgen-*")
	if err != nil {
		return err
	}
	tempPath := temp.Name()
	cleanup := func() {
		_ = os.Remove(tempPath)
	}
	if _, err = temp.Write(data); err != nil {
		_ = temp.Close()
		cleanup()
		return err
	}
	if err = temp.Chmod(perm); err != nil {
		_ = temp.Close()
		cleanup()
		return err
	}
	if err = temp.Close(); err != nil {
		cleanup()
		return err
	}
	// NOTE(review): os.Rename over an existing file fails on Windows — confirm
	// the supported platforms if that matters here.
	if err = os.Rename(tempPath, path); err != nil {
		cleanup()
		return err
	}
	return nil
}

// fileExists reports whether path exists and is a regular (non-directory) entry.
func fileExists(path string) (bool, error) {
	info, err := os.Stat(path)
	if err == nil {
		return !info.IsDir(), nil
	}
	if os.IsNotExist(err) {
		return false, nil
	}
	return false, err
}

// mergeGeneratedShapes splices freshly generated type declarations into an
// existing generated file: previously generated declarations with the same
// names are removed, imports from both files are united, and any hand-written
// declarations (functions, other types) are preserved ahead of the new shapes.
func mergeGeneratedShapes(dest string, generated []byte, typeNames []string) ([]byte, error) {
	existing, err := os.ReadFile(dest)
	if err != nil {
		return nil, err
	}
	if len(existing) == 0 {
		// nothing to preserve: the generated output stands alone
		return generated, nil
	}
	if len(typeNames) == 0 {
		// no shapes to splice: keep the existing file untouched
		return existing, nil
	}

	fset := token.NewFileSet()
	existingFile, err := parser.ParseFile(fset, dest, existing, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("shape xgen: parse existing file failed: %w", err)
	}
	generatedFile, err := parser.ParseFile(token.NewFileSet(), "", generated, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("shape xgen: parse generated file failed: %w", err)
	}
	typeNameSet := map[string]bool{}
	for _, name := range typeNames {
		typeNameSet[name] = true
	}

	shapeDecls := generatedShapeDecls(generatedFile, typeNameSet)
	if len(shapeDecls) == 0 {
		return generated, nil
	}
	mergedImports := mergeImports(existingFile.Imports, generatedFile.Imports)

	// Rebuild declarations: imports first, then kept existing decls, then new shapes.
	newDecls := make([]ast.Decl, 0, len(existingFile.Decls)+len(shapeDecls)+1)
	if len(mergedImports) > 0 {
		newDecls = append(newDecls, &ast.GenDecl{
			Tok:   token.IMPORT,
			Specs: mergedImports,
		})
	}

	for _, decl := range existingFile.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok {
			newDecls = append(newDecls, decl)
			continue
		}
		switch gen.Tok {
		case token.IMPORT:
			// original import blocks were folded into mergedImports above
			continue
		case token.TYPE:
			// drop type specs that will be re-emitted by the generator
			filtered := make([]ast.Spec, 0, len(gen.Specs))
			for _, spec := range gen.Specs {
				ts, ok := spec.(*ast.TypeSpec)
				if !ok || !typeNameSet[ts.Name.Name] {
					filtered = append(filtered, spec)
				}
			}
			if len(filtered) == 0 {
				continue
			}
			gen.Specs = filtered
			newDecls = append(newDecls, gen)
		default:
			newDecls = append(newDecls, decl)
		}
	}
	newDecls = append(newDecls, shapeDecls...)
	existingFile.Decls = newDecls
	existingFile.Imports = importSpecsToImportNodes(mergedImports)

	var out bytes.Buffer
	if err = format.Node(&out, fset, existingFile); err != nil {
		return nil, fmt.Errorf("shape xgen: format merged file failed: %w", err)
	}
	return out.Bytes(), nil
}

// generatedShapeDecls collects the TYPE declarations from the generated file
// whose spec names are in typeNameSet, one GenDecl per original group.
func generatedShapeDecls(file *ast.File, typeNameSet map[string]bool) []ast.Decl {
	var result []ast.Decl
	for _, decl := range file.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.TYPE {
			continue
		}
		filtered := make([]ast.Spec, 0, len(gen.Specs))
		for _, spec := range gen.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok || !typeNameSet[ts.Name.Name] {
				continue
			}
			filtered = append(filtered, spec)
		}
		if len(filtered) == 0 {
			continue
		}
		result = append(result, &ast.GenDecl{
			Tok:   token.TYPE,
			Specs: filtered,
		})
	}
	return result
}

// mergeImports unions both import lists, de-duplicating by path+alias and
// sorting the combined keys so the output is deterministic.
func mergeImports(existing []*ast.ImportSpec, generated []*ast.ImportSpec) []ast.Spec {
	merged := map[string]*ast.ImportSpec{}
	add := func(item *ast.ImportSpec) {
		if item == nil || item.Path == nil {
			return
		}
		key := item.Path.Value + "|" + importAlias(item)
		if _, ok := merged[key]; ok {
			return
		}
		merged[key] = item
	}
	for _, item := range existing {
		add(item)
	}
	for _, item := range generated {
		add(item)
	}
	keys := make([]string, 0, len(merged))
	for key := range merged {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	result := make([]ast.Spec, 0, len(keys))
	for _, key := range keys {
		result = append(result, merged[key])
	}
	return result
}

func importAlias(item 
*ast.ImportSpec) string { + if item == nil || item.Name == nil { + return "" + } + return item.Name.Name +} + +func importSpecsToImportNodes(specs []ast.Spec) []*ast.ImportSpec { + result := make([]*ast.ImportSpec, 0, len(specs)) + for _, spec := range specs { + if item, ok := spec.(*ast.ImportSpec); ok { + result = append(result, item) + } + } + return result +} diff --git a/repository/shape/xgen/model.go b/repository/shape/xgen/model.go new file mode 100644 index 000000000..623f79be8 --- /dev/null +++ b/repository/shape/xgen/model.go @@ -0,0 +1,70 @@ +package xgen + +import "github.com/viant/x" + +type ( + ViewTypeContext struct { + ViewName string + SchemaName string + } + + RouteTypeContext struct { + RouteName string + RouteURI string + RouteRef string + TypeName string + } +) + +// Config controls shape->Go generation. +type Config struct { + // ProjectDir points to target Go project root. + ProjectDir string + // PackageDir points to package directory inside the project (relative or absolute). + PackageDir string + // PackageName sets generated package name; defaults to basename(PackageDir). + PackageName string + // PackagePath sets fully-qualified import path; when empty it's derived from go.mod + PackageDir. + PackagePath string + // FileName sets generated filename; defaults to shapes_gen.go. + FileName string + // TypePrefix prefixes generated type names. + TypePrefix string + // ViewSuffix appends suffix to generated view type names when schema name is absent. + ViewSuffix string + // InputSuffix appends suffix to generated route input type names when explicit type name is absent. + InputSuffix string + // OutputSuffix appends suffix to generated route output type names when explicit type name is absent. + OutputSuffix string + // ViewTypeNamer customizes final view type name. + ViewTypeNamer func(ctx ViewTypeContext) string + // InputTypeNamer customizes final input type name. 
+ InputTypeNamer func(ctx RouteTypeContext) string + // OutputTypeNamer customizes final output type name. + OutputTypeNamer func(ctx RouteTypeContext) string + // Registry allows reusing an external viant/x registry. + Registry *x.Registry + // AllowUnsafeRewrite allows overwriting existing generated files even when + // type provenance indicates unresolved/unsafe origins. Default false. + AllowUnsafeRewrite bool + // AllowedProvenanceKinds controls which provenance kinds are trusted for updates. + // Defaults to builtin, resource_type and ast_type. + AllowedProvenanceKinds []string + // AllowedSourceRoots controls additional trusted roots for provenance files. + // ProjectDir is always implicitly trusted. + AllowedSourceRoots []string + // UseGoModuleResolve enables go.mod + replace-based source resolution. Default true. + UseGoModuleResolve *bool + // UseGOPATHFallback enables GOPATH/src fallback when go.mod resolution misses. Default true. + UseGOPATHFallback *bool + // StrictProvenance blocks updates on policy violations. Default true. + StrictProvenance *bool +} + +// Result captures generation outputs. 
+type Result struct { + FilePath string + PackagePath string + PackageName string + Types []string +} diff --git a/router.go b/router.go new file mode 100644 index 000000000..03504a4eb --- /dev/null +++ b/router.go @@ -0,0 +1,130 @@ +package datly + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/xdatly/handler/response" + hstate "github.com/viant/xdatly/handler/state" +) + +type Handler[T any] func(ctx context.Context, service T, request *http.Request, injector hstate.Injector, extra ...OperateOption) (interface{}, error) + +type Route[T any] struct { + dao *Service + handler Handler[T] + service T + path *contract.Path + component *repository.Component +} + +func (r Route[T]) ensureComponent(ctx context.Context) (*repository.Component, error) { + if r.component == nil { + var err error + r.component, err = r.dao.repository.Registry().Lookup(ctx, r.path) + if err != nil { + return nil, err + } + } + return r.component, nil +} + +func (r Route[T]) Run(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { + marshaller, contentType, _, err := r.dao.getMarshaller(request, r.component) + if err != nil { + return fmt.Errorf("failed to lookup marshaller: %w", err) + } + injector, err := r.dao.GetInjector(request, r.component) + if err != nil { + return fmt.Errorf("failed to lookup injector: %w", err) + } + selectors := []*hstate.NamedQuerySelector{} + values := request.URL.Query() + if page := values.Get("page"); page != "" { + selector := &hstate.NamedQuerySelector{Name: r.component.View.Name} + selector.Page, _ = strconv.Atoi(page) + selectors = append(selectors, selector) + } + result, err := r.handler(ctx, r.service, request, injector, WithSessionOptions(WithRequest(request), WithQuerySelectors(selectors...))) + var data []byte + if err != nil { + rErr, ok := err.(*response.Error) + if !ok { + rErr = 
response.NewError(http.StatusInternalServerError, err.Error()) + } + data, err = marshaller(rErr) + } else { + data, err = marshaller(result) + } + if err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + return nil + } + statusCode := http.StatusOK + statusCoder, ok := result.(response.StatusCoder) + if ok { + statusCode = statusCoder.StatusCode() + } + + writer.Header().Set("Content-Type", contentType) + writer.WriteHeader(statusCode) + _, err = writer.Write(data) + return err +} + +func newRoute[T any](dao *Service, path *contract.Path, component *repository.Component, service T, handler Handler[T]) *Route[T] { + return &Route[T]{path: path, handler: handler, dao: dao, component: component, service: service} +} + +type Router[T any] struct { + registry map[string]*Route[T] + dao *Service + service T +} + +type routeNotFound struct { + error +} + +// IsRouteNotFound checks if error is route not found +func IsRouteNotFound(err error) bool { + _, ok := err.(*routeNotFound) + return ok +} + +func (r *Router[T]) Run(writer http.ResponseWriter, request *http.Request) error { + aPath := contract.NewPath(request.Method, request.URL.Path) + component, err := r.dao.repository.Registry().Lookup(request.Context(), aPath) + if err != nil { + fmt.Println(err) + return &routeNotFound{err} + } + route, ok := r.registry[component.Path.Key()] + if !ok { + return &routeNotFound{errors.New("route not found")} + } + return route.Run(request.Context(), writer, request) +} + +func (r *Router[T]) Register(ctx context.Context, path *contract.Path, handler Handler[T]) error { + component, err := r.dao.repository.Registry().Lookup(ctx, path) + if err != nil { + return fmt.Errorf("failed to lookup component: %w for path: %+v", err, path) + } + route := newRoute[T](r.dao, path, component, r.service, handler) + r.registry[path.Key()] = route + return nil +} + +func NewRouter[T any](dao *Service, service T) *Router[T] { + return &Router[T]{registry: 
make(map[string]*Route[T]), dao: dao, service: service} +} + +type BodyEnvelope[T any] struct { + Body T `parameter:",kind=body"` +} diff --git a/service.go b/service.go index 68f537d55..96e3d609d 100644 --- a/service.go +++ b/service.go @@ -4,33 +4,39 @@ import ( "context" _ "embed" "fmt" + "github.com/viant/cloudless/async/mbus" "github.com/viant/datly/gateway" "github.com/viant/datly/repository" + rcontent "github.com/viant/datly/repository/content" "github.com/viant/datly/repository/contract" "github.com/viant/datly/repository/locator/component/dispatcher" + srv "github.com/viant/datly/service" sjwt "github.com/viant/datly/service/auth/jwt" "github.com/viant/datly/service/auth/mock" "github.com/viant/datly/service/executor" "github.com/viant/datly/service/operator" "github.com/viant/datly/service/reader" "github.com/viant/datly/service/session" + "github.com/viant/datly/shared" "github.com/viant/datly/view" "github.com/viant/datly/view/extension" + "github.com/viant/datly/view/state/kind/locator" verifier2 "github.com/viant/scy/auth/jwt/verifier" hstate "github.com/viant/xdatly/handler/state" + "net/http" + nurl "net/url" + "reflect" + "strings" + "time" + "github.com/viant/datly/view/state" "github.com/viant/scy/auth/jwt" "github.com/viant/scy/auth/jwt/signer" "github.com/viant/structology" "github.com/viant/xdatly/codec" xhandler "github.com/viant/xdatly/handler" - "net/http" - nurl "net/url" - "reflect" - "strings" - "time" ) //go:embed Version @@ -50,16 +56,18 @@ type ( } sessionOptions struct { - request *http.Request - resource state.Resource - form *hstate.Form + request *http.Request + resource state.Resource + form *hstate.Form + querySelectors []*hstate.NamedQuerySelector } SessionOption func(o *sessionOptions) operateOptions struct { - path *contract.Path - component *repository.Component - session *session.Session + path *contract.Path + component *repository.Component + session *session.Session + output interface{} input interface{} sessionOptions 
[]SessionOption @@ -151,6 +159,12 @@ func WithForm(form *hstate.Form) SessionOption { } } +func WithQuerySelectors(selectors ...*hstate.NamedQuerySelector) SessionOption { + return func(o *sessionOptions) { + o.querySelectors = selectors + } +} + func WithStateResource(resource state.Resource) SessionOption { return func(o *sessionOptions) { o.resource = resource @@ -160,8 +174,12 @@ func WithStateResource(resource state.Resource) SessionOption { func (s *Service) NewComponentSession(aComponent *repository.Component, opts ...SessionOption) *session.Session { sessionOpt := newSessionOptions(opts) options := aComponent.LocatorOptions(sessionOpt.request, sessionOpt.form, aComponent.UnmarshalFunc(sessionOpt.request)) + if sessionOpt.querySelectors != nil { + options = append(options, locator.WithQuerySelectors(sessionOpt.querySelectors)) + } aSession := session.New(aComponent.View, session.WithLocatorOptions(options...), session.WithAuth(s.repository.Auth()), + session.WithComponent(aComponent), session.WithStateResource(sessionOpt.resource), session.WithOperate(s.operator.Operate)) return aSession } @@ -193,7 +211,12 @@ func (s *Service) SignRequest(request *http.Request, claims *jwt.Claims) error { func LoadInput(ctx context.Context, aSession *session.Session, aComponent *repository.Component, input interface{}) error { ctx = aSession.Context(ctx, false) - if err := aSession.LoadState(aComponent.Input.Type.Parameters, input); err != nil { + if err := aSession.LoadState( + aComponent.Input.Type.Parameters, + input, + session.WithHasMarker(), + session.WithValuePresenceFallback(), + ); err != nil { return err } if err := aSession.Populate(ctx); err != nil { @@ -268,6 +291,115 @@ func (s *Service) PopulateInput(ctx context.Context, aComponent *repository.Comp return nil } +func (s *Service) GetInjector(r *http.Request, comp *repository.Component) (hstate.Injector, error) { + if err := s.ensureComponentInitialized(comp); err != nil { + return nil, err + } + // Build 
component session to populate state (for exclusion filters) + sess := s.NewComponentSession(comp, WithRequest(r), WithStateResource(comp.View.Resource())) + return sess, nil +} + +// GetMarshaller prepares a request-scoped marshaller closure and resolved content type for the given component path. +// It preserves existing behavior for readers (format derived from query) and defaults to JSON otherwise. +func (s *Service) GetMarshaller(r *http.Request, methodAndPath string, extra ...repository.MarshalOption) (marshal shared.Marshal, contentType string, comp *repository.Component, err error) { + comp, err = s.Component(r.Context(), methodAndPath) + if err != nil || comp == nil { + if err == nil { + err = fmt.Errorf("component not found: %s", methodAndPath) + } + return nil, "", nil, err + } + return s.getMarshaller(r, comp, extra...) +} + +func (s *Service) getMarshaller(r *http.Request, comp *repository.Component, extra ...repository.MarshalOption) (shared.Marshal, string, *repository.Component, error) { + // Ensure component content marshallers are initialized (defensive when invoked outside router lifecycle) + if err := s.ensureComponentInitialized(comp); err != nil { + return nil, "", nil, err + } + + // Build component session to populate state (for exclusion filters) + sess := s.NewComponentSession(comp, WithRequest(r), WithStateResource(comp.View.Resource())) + // Compute JSON field filters from populated state + filters := comp.Exclusion(sess.State()) + + // Optional format override from query parameter `format` + override := strings.TrimSpace(r.URL.Query().Get("format")) + + var opts []repository.MarshalOption + opts = append(opts, repository.WithRequest(r), repository.WithFilters(filters)) + if override != "" { + opts = append(opts, repository.WithFormat(override)) + } + if len(extra) > 0 { + opts = append(opts, extra...) + } + + // Prepare marshaller closure + marshal := comp.MarshalFunc(opts...) 
+ + // Resolve content type for headers + resolved := override + if resolved == "" && comp.Service == srv.TypeReader { + resolved = comp.Output.Format(r.URL.Query()) + } + if resolved == "" { + resolved = rcontent.JSONFormat + } + contentType := comp.Output.ContentType(resolved) + return marshal, contentType, comp, nil +} + +// GetUnmarshaller prepares a request-scoped unmarshaller for the given component path. +func (s *Service) GetUnmarshaller(r *http.Request, methodAndPath string, extra ...repository.UnmarshalOption) (unmarshal shared.Unmarshal, comp *repository.Component, err error) { + comp, err = s.Component(r.Context(), methodAndPath) + if err != nil || comp == nil { + if err == nil { + err = fmt.Errorf("component not found: %s", methodAndPath) + } + return nil, nil, err + } + return s.getUnmarshaller(r, comp, extra...) +} + +func (s *Service) getUnmarshaller(r *http.Request, comp *repository.Component, extra ...repository.UnmarshalOption) (shared.Unmarshal, *repository.Component, error) { + // Ensure component content marshallers are initialized (defensive) + if err := s.ensureComponentInitialized(comp); err != nil { + return nil, nil, err + } + var opts []repository.UnmarshalOption + opts = append(opts, repository.WithUnmarshalRequest(r)) + if len(extra) > 0 { + opts = append(opts, extra...) + } + unmarshal := comp.UnmarshalFor(opts...) + return unmarshal, comp, nil +} + +// ensureComponentInitialized defensively initializes component content marshallers when called from external contexts. +func (s *Service) ensureComponentInitialized(comp *repository.Component) error { + if comp == nil { + return fmt.Errorf("component was nil") + } + res := comp.View.GetResource() + if res == nil { + return nil + } + // If JSON marshaller already present, assume initialized. 
+ if comp.Content.Marshaller.JSON.JsonMarshaller != nil { + return nil + } + // Initialize content marshallers as in Component.Init + if err := comp.Content.InitMarshaller(comp.IOConfig(), comp.Output.Exclude, comp.BodyType(), comp.OutputType()); err != nil { + return err + } + if err := comp.Content.Marshaller.Init(res.LookupType()); err != nil { + return err + } + return nil +} + // Read reads data from a view func (s *Service) Read(ctx context.Context, locator string, dest interface{}, option ...reader.Option) error { aView, err := s.View(ctx, wrapWithMethod(http.MethodGet, locator)) @@ -517,7 +649,7 @@ func (s *Service) HTTPHandler(ctx context.Context, options ...gateway.Option) (h return s.handler, nil } -// New creates a datly service, repository allows you to bootstrap empty or existing yaml repository +// New creates a datly service; repository options allow you to bootstrap an empty or existing yaml repository func New(ctx context.Context, options ...repository.Option) (*Service, error) { options = append([]repository.Option{ repository.WithJWTSigner(mock.HmacJwtSigner()), diff --git a/service/executor/expand/data_unit.go b/service/executor/expand/data_unit.go index 66a694dad..9d2c22dad 100644 --- a/service/executor/expand/data_unit.go +++ b/service/executor/expand/data_unit.go @@ -17,23 +17,25 @@ import ( type ( DataUnit struct { - Columns codec.ColumnsSource - ParamsGroup []interface{} - Mock bool - TemplateSQL string - MetaSource Dber `velty:"-"` - Statements *Statements `velty:"-"` - + Columns codec.ColumnsSource + ParamsGroup []interface{} + Mock bool + TemplateSQL string + MetaSource Dber `velty:"-"` + Statements *Statements `velty:"-"` mu sync.Mutex `velty:"-"` placeholderCounter int `velty:"-"` sqlxValidator *validator.Service `velty:"-"` sliceIndex map[reflect.Type]*xunsafe.Slice `velty:"-"` ctx context.Context `velty:"-"` + EvalLock sync.Mutex } ExecutablesIndex map[string]*Executable ) +// WithPresence returns a validator set-marker option boxed as interface{}. + func (c *DataUnit) WithPresence() interface{} { var opt interface{} = 
validator.WithSetMarker() return opt @@ -43,17 +45,6 @@ func (c *DataUnit) WithLocation(loc string) interface{} { return opt } -// Reset clears binding-related state so DataUnit can be safely reused for a new evaluation -func (c *DataUnit) Reset() { - c.mu.Lock() - c.placeholderCounter = 0 - if len(c.ParamsGroup) > 0 { - c.ParamsGroup = c.ParamsGroup[:0] - } - c.TemplateSQL = "" - c.mu.Unlock() -} - func (c *DataUnit) Validate(dest interface{}, opts ...interface{}) (*validator.Validation, error) { db, err := c.MetaSource.Db() if err != nil { @@ -157,7 +148,7 @@ func (c *DataUnit) Next() (interface{}, error) { return c.ParamsGroup[index], nil } - return nil, fmt.Errorf("expected to get binding parameter, but noone was found, ParamsGroup: %v, placeholderCounter: %v", c.ParamsGroup, c.placeholderCounter) + return nil, fmt.Errorf("expected to get binding parameter, but none was found, ParamsGroup: %v, placeholderCounter: %v", c.ParamsGroup, c.placeholderCounter) } func (c *DataUnit) ensureSliceIndex() { @@ -187,6 +178,12 @@ func (c *DataUnit) addAll(args ...interface{}) { c.mu.Unlock() } +func (c *DataUnit) Shrink(offset int) { + c.mu.Lock() + c.ParamsGroup = c.ParamsGroup[:offset] + c.mu.Unlock() +} + func (c *DataUnit) IsServiceExec(SQL string) (*Executable, bool) { return c.Statements.LookupExecutable(SQL) } @@ -278,6 +275,17 @@ func (c *DataUnit) Like(columnName string, args interface{}) (string, error) { func (c *DataUnit) NotLike(columnName string, args interface{}) (string, error) { return c.like(columnName, args, false) } +func (c *DataUnit) Expression(expr string, value interface{}) (string, error) { + return c.expression(expr, value) +} + +func (c *DataUnit) expression(expr string, value interface{}) (string, error) { + if value == "" { + return "", nil + } + c.addAll(value) + return expr, nil +} func (c *DataUnit) like(columnName string, args interface{}, inclusive bool) (string, error) { expander, err := bindingsCache.Lookup(args) diff --git 
a/service/executor/expand/evaluator.go b/service/executor/expand/evaluator.go index d57b935b2..0c6dee479 100644 --- a/service/executor/expand/evaluator.go +++ b/service/executor/expand/evaluator.go @@ -35,7 +35,12 @@ type ( func WithCustomContexts(ctx ...*Variable) EvaluatorOption { return func(c *config) { - c.embededTypes = append(c.embededTypes, ctx...) + for _, item := range ctx { + if item == nil { + continue + } + c.embededTypes = append(c.embededTypes, item) + } } } @@ -47,7 +52,12 @@ func WithContext(ctx context.Context) EvaluatorOption { func WithVariable(namedVariable ...*NamedVariable) EvaluatorOption { return func(c *config) { - c.namedVariables = append(c.namedVariables, namedVariable...) + for _, item := range namedVariable { + if item == nil { + continue + } + c.namedVariables = append(c.namedVariables, item) + } } } @@ -65,6 +75,9 @@ func WithSetLiteral(setLiterals func(state *structology.State) error) EvaluatorO func WithTypeLookup(lookup xreflect.LookupType) EvaluatorOption { return func(c *config) { + if lookup == nil { + return + } c.typeLookup = lookup } } @@ -141,12 +154,18 @@ func NewEvaluator(template string, options ...EvaluatorOption) (*Evaluator, erro } for _, valueType := range aConfig.embededTypes { + if valueType == nil { + continue + } if err = evaluator.planner.EmbedVariable(valueType.Type); err != nil { return nil, err } } for _, variable := range aConfig.namedVariables { + if variable == nil { + continue + } if err = evaluator.planner.DefineVariable(variable.Name, variable.Type); err != nil { return nil, err } @@ -181,6 +200,9 @@ func NewEvaluator(template string, options ...EvaluatorOption) (*Evaluator, erro func createConfig(options []EvaluatorOption) *config { instance := newConfig() for _, option := range options { + if option == nil { + continue + } option(instance) } @@ -252,7 +274,7 @@ func (e *Evaluator) ensureState(ctx *Context, options ...StateOption) *State { state.Context = ctx } - state.Init(e.stateProvider(), 
e.predicateConfigs, options...) + state.Init(e.stateProvider(), e.predicateConfigs, e.stateType, options...) return state } diff --git a/service/executor/expand/evaluator_test.go b/service/executor/expand/evaluator_test.go new file mode 100644 index 000000000..b6e0a7e2d --- /dev/null +++ b/service/executor/expand/evaluator_test.go @@ -0,0 +1,78 @@ +package expand_test + +import ( + "testing" + + "github.com/viant/datly/service/executor/expand" +) + +func TestNewEvaluator_DefaultTypeLookup(t *testing.T) { + evaluator, err := expand.NewEvaluator(`#set($x = $New("int"))$x`) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestNewEvaluator_WithNilTypeLookupOption(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + evaluator, err := expand.NewEvaluator(`#set($x = $New("int"))$x`, expand.WithTypeLookup(nil)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestNewEvaluator_UnknownTypeReturnsError(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + _, err := expand.NewEvaluator(`#set($x = $New("DefinitelyNotAType"))$x`, expand.WithTypeLookup(nil)) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestNewEvaluator_WithNilNamedVariableOption(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + evaluator, err := expand.NewEvaluator(`ok`, expand.WithVariable(nil)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestNewEvaluator_WithNilCustomContextOption(t *testing.T) { + 
defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + evaluator, err := expand.NewEvaluator(`ok`, expand.WithCustomContexts(nil)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} diff --git a/service/executor/expand/fn_new.go b/service/executor/expand/fn_new.go index 6f5b78286..6cbdf10ed 100644 --- a/service/executor/expand/fn_new.go +++ b/service/executor/expand/fn_new.go @@ -41,7 +41,7 @@ func (n *newer) NewResultType(call *expr.Call) (reflect.Type, error) { expression, ok := call.Args[0].(*expr.Literal) if !ok { - return nil, fmt.Errorf("expected arg to be type of %T but was %T", expression, call.Args[1]) + return nil, fmt.Errorf("expected arg to be type of %T but was %T", expression, call.Args[0]) } return types.LookupType(n.lookup, expression.Value) diff --git a/service/executor/expand/fn_printer.go b/service/executor/expand/fn_printer.go index 3eaabe2c9..620e7ef57 100644 --- a/service/executor/expand/fn_printer.go +++ b/service/executor/expand/fn_printer.go @@ -61,7 +61,7 @@ func (p *Printer) Println(args ...interface{}) string { func (p *Printer) Printf(format string, args ...interface{}) string { p.derefArgs(args) - fmt.Printf(p.Sprintf(format, args...)) + fmt.Print(p.Sprintf(format, args...)) return "" } @@ -107,12 +107,12 @@ func (p *Printer) Fatal(any interface{}, args ...interface{}) (string, error) { format, ok := any.(string) if ok { - return "", fmt.Errorf(p.Sprintf(format, args...)) + return "", fmt.Errorf("%s", p.Sprintf(format, args...)) } if err, ok := any.(error); ok { return "", err } - return "", fmt.Errorf(p.Sprintf("%+v", any)) + return "", fmt.Errorf("%s", p.Sprintf("%+v", any)) } // Fatalf fatal with formatting @@ -124,7 +124,7 @@ func (p *Printer) Fatalf(any interface{}, args ...interface{}) (string, error) { func (p *Printer) FatalfWithCode(code int, any interface{}, 
args ...interface{}) (string, error) { format, ok := any.(string) if ok { - return "", response.NewError(code, fmt.Sprintf(p.Sprintf(format, args...))) + return "", response.NewError(code, p.Sprintf(format, args...)) } if err, ok := any.(error); ok { return "", response.NewError(code, err.Error(), response.WithError(err)) diff --git a/service/executor/expand/predicate.go b/service/executor/expand/predicate.go index b67e83312..aae51c8bb 100644 --- a/service/executor/expand/predicate.go +++ b/service/executor/expand/predicate.go @@ -38,7 +38,11 @@ type ( } ) -func NewPredicate(ctx *Context, state *structology.State, config []*PredicateConfig) *Predicate { +func NewPredicate(ctx *Context, state *structology.State, config []*PredicateConfig, stateType *structology.StateType) *Predicate { + // Initialize state if not provided, but never override an existing state + if state == nil && stateType != nil { + state = stateType.NewState() + } return &Predicate{ ctx: ctx, config: config, @@ -129,6 +133,10 @@ func (p *Predicate) expand(group int, operator string) (string, error) { } ctx = vcontext.WithValue(ctx, PredicateCtx, p.ctx) ctx = vcontext.WithValue(ctx, PredicateState, p.state) + + p.ctx.DataUnit.EvalLock.Lock() + defer p.ctx.DataUnit.EvalLock.Unlock() + if p.ctx.Session != nil { aLogger := p.ctx.Session.Logger() ctx = vcontext.WithValue(ctx, logger.ContextKey, aLogger) diff --git a/service/executor/expand/state.go b/service/executor/expand/state.go index f970ff679..6399be43d 100644 --- a/service/executor/expand/state.go +++ b/service/executor/expand/state.go @@ -2,6 +2,7 @@ package expand import ( "context" + "github.com/viant/datly/service/executor/extension" "github.com/viant/datly/view/state/predicate" @@ -83,7 +84,7 @@ func WithCustomContext(customContext *Variable) StateOption { } } -func (s *State) Init(templateState *est.State, predicates []*PredicateConfig, options ...StateOption) { +func (s *State) Init(templateState *est.State, predicates []*PredicateConfig, 
stateType *structology.StateType, options ...StateOption) { for _, option := range options { option(s) } @@ -103,8 +104,6 @@ func (s *State) Init(templateState *est.State, predicates []*PredicateConfig, op if s.DataUnit == nil { s.DataUnit = NewDataUnit(nil) } - // Ensure bindings/cursor are reset for a fresh evaluation cycle - s.DataUnit.Reset() if s.Http == nil { s.Http = &Http{} @@ -122,7 +121,7 @@ func (s *State) Init(templateState *est.State, predicates []*PredicateConfig, op s.MessageBus = s.Session.MessageBus() } - s.Predicate = NewPredicate(s.Context, s.ParametersState, predicates) + s.Predicate = NewPredicate(s.Context, s.ParametersState, predicates, stateType) s.State = templateState } @@ -149,6 +148,6 @@ func StateWithSQL(ctx context.Context, SQL string) *State { Context: &Context{Context: ctx}, } - aState.Init(nil, nil) + aState.Init(nil, nil, nil) return aState } diff --git a/service/executor/extension/session.go b/service/executor/extension/session.go index d52f72f5d..9fbfe51fc 100644 --- a/service/executor/extension/session.go +++ b/service/executor/extension/session.go @@ -19,7 +19,7 @@ import ( type ( Session struct { sqlService SqlServiceFn - stater state.Stater + injector state.Injector validator *validator.Service differ *differ.Service mbus *xmbus.Service @@ -92,7 +92,7 @@ func (s *Session) Db(opts ...sqlx.Option) (*sqlx.Service, error) { } func (s *Session) Stater() *state.Service { - return state.New(s.stater) + return state.New(s.injector) } func (s *Session) FlushTemplate(ctx context.Context) error { @@ -148,8 +148,8 @@ func WithMessageBus(messageBusses []*mbus.Resource) Option { } } -func WithStater(stater state.Stater) Option { +func WithStater(injector state.Injector) Option { return func(s *Session) { - s.stater = stater + s.injector = injector } } diff --git a/service/executor/extension/validator.go b/service/executor/extension/validator.go index 8c23f0df4..740b16f37 100644 --- a/service/executor/extension/validator.go +++ 
b/service/executor/extension/validator.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + + derrors "github.com/viant/datly/utils/errors" "github.com/viant/datly/utils/httputils" "github.com/viant/govalidator" sqlxvalidator "github.com/viant/sqlx/io/validator" @@ -33,6 +35,9 @@ func (v *SqlxValidator) Validate(ctx context.Context, any interface{}, opts ...v err = v.validator.validateWithSqlx(ctx, any, validation, options) } if err != nil { + if derrors.IsDatabaseError(err) { + return validation, err + } validation.Append("/", "", "", "error", err.Error()) } return validation, nil @@ -46,9 +51,15 @@ func (v *Validator) Validate(ctx context.Context, any interface{}, opts ...valid validation := getOrCreateValidation(options) err := v.validateWithGoValidator(ctx, any, validation, options) if err != nil { + if derrors.IsDatabaseError(err) { + return validation, err + } validation.Append("/", "", "", "error", err.Error()) } if err = v.validateWithSqlx(ctx, any, validation, options); err != nil { + if derrors.IsDatabaseError(err) { + return validation, err + } validation.Append("/", "", "", "error", err.Error()) } return validation, nil diff --git a/service/executor/handler/executor.go b/service/executor/handler/executor.go index 4d5e719ae..70f1746ba 100644 --- a/service/executor/handler/executor.go +++ b/service/executor/handler/executor.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + "net/http" + "github.com/viant/datly/repository" "github.com/viant/datly/repository/contract" executor "github.com/viant/datly/service/executor" @@ -20,7 +22,6 @@ import ( "github.com/viant/xdatly/handler/sqlx" hstate "github.com/viant/xdatly/handler/state" "github.com/viant/xdatly/handler/validator" - "net/http" ) type ( @@ -94,6 +95,12 @@ func (e *Executor) Session(ctx context.Context) (*executor.Session, error) { e.executorSession = sess sess.SessionHandler = sessionHandler + // inherit tx from session options if available + if e.tx == nil { + if tx := 
e.session.Options.SqlTx(); tx != nil { + e.tx = tx + } + } return e.executorSession, err } @@ -126,6 +133,9 @@ func (e *Executor) newSession(aSession *session.Session, opts ...Option) *extens if options.auth != nil { e.auth = options.auth } + if e.logger == nil { + e.logger = options.logger + } res := e.view.GetResource() sess := extension.NewSession( extension.WithTemplateFlush(func(ctx context.Context) error { @@ -135,6 +145,7 @@ func (e *Executor) newSession(aSession *session.Session, opts ...Option) *extens extension.WithRedirect(e.redirect), extension.WithSql(e.newSqlService), extension.WithHttp(e.newHttp), + extension.WithLogger(e.logger), extension.WithAuth(e.newAuth), extension.WithMessageBus(res.MessageBuses), ) @@ -157,6 +168,10 @@ func (e *Executor) newSqlService(options *sqlx.Options) (sqlx.Sqlx, error) { if unit == e.dataUnit { //we are using View that can contain SQL Statements in Velty txStartedNotifier = e.txStarted } + // default SQLx tx to executor tx to avoid internal Begin/Commit if caller provided one + if options.WithTx == nil && e.tx != nil { + options.WithTx = e.tx + } return &Service{ txNotifier: txStartedNotifier, dataUnit: unit, @@ -168,6 +183,7 @@ func (e *Executor) newSqlService(options *sqlx.Options) (sqlx.Sqlx, error) { } func (e *Executor) getDataUnit(options *sqlx.Options) (*expand.DataUnit, error) { + e.ensureConnectors() if (options.WithDb == nil && options.WithTx == nil) && options.WithConnector == e.view.Connector.Name { return e.dataUnit, nil } @@ -192,6 +208,11 @@ func (e *Executor) getDataUnit(options *sqlx.Options) (*expand.DataUnit, error) if connector == nil { return nil, fmt.Errorf("failed to lookup connector %v", options.WithConnector) } + + if _, ok := e.connectors[options.WithConnector]; !ok { + e.connectors[options.WithConnector] = connector + } + db, err := connector.DB() if err != nil { return nil, err @@ -206,6 +227,17 @@ func (e *Executor) getDataUnit(options *sqlx.Options) (*expand.DataUnit, error) return 
e.dataUnit, nil } +func (e *Executor) ensureConnectors() { + if len(e.connectors) == 0 { + e.connectors = make(view.Connectors) + if res := e.view.GetResource(); res != nil { + for _, connector := range res.Connectors { + e.connectors[connector.Name] = connector + } + } + } +} + func (e *Executor) Execute(ctx context.Context) error { if e.executed { return nil @@ -217,6 +249,10 @@ func (e *Executor) Execute(ctx context.Context) error { dbOptions = append(dbOptions, executor.WithTx(e.tx)) } + err := service.ExecuteStmts(ctx, executor.NewViewDBSource(e.view), newSqlxIterator(e.dataUnit.Statements.Executable), dbOptions...) + if err != nil { + return err + } for _, unit := range e.dataUnits { dbSource := &DbSource{} dbSource.db, _ = unit.MetaSource.Db() @@ -225,7 +261,7 @@ func (e *Executor) Execute(ctx context.Context) error { } } - return service.ExecuteStmts(ctx, executor.NewViewDBSource(e.view), newSqlxIterator(e.dataUnit.Statements.Executable), dbOptions...) + return err } func (e *Executor) ExpandAndExecute(ctx context.Context) (*executor.Session, error) { @@ -262,7 +298,6 @@ func (e *Executor) redirect(ctx context.Context, route *http2.Route, opts ...hst request.Header = originalRequest.Header } stateOptions := hstate.NewOptions(opts...) 
- unmarshal := aComponent.UnmarshalFunc(request) locatorOptions := append(aComponent.LocatorOptions(request, hstate.NewForm(), unmarshal)) if stateOptions.Query() != nil { @@ -286,8 +321,13 @@ func (e *Executor) redirect(ctx context.Context, route *http2.Route, opts ...hst session.WithOperate(e.session.Options.Operate()), session.WithTypes(&aComponent.Contract.Input.Type, &aComponent.Contract.Output.Type), session.WithComponent(aComponent), + session.WithLogger(e.logger), session.WithRegistry(registry), ) + if tx := stateOptions.SqlTx(); tx != nil { + // associate tx with session; child executor will reuse it + aSession.Apply(session.WithSQLTx(tx)) + } err = aSession.InitKinds(state.KindComponent, state.KindHeader, state.KindRequestBody, state.KindForm, state.KindQuery) if err != nil { @@ -295,7 +335,11 @@ func (e *Executor) redirect(ctx context.Context, route *http2.Route, opts ...hst } ctx = aSession.Context(ctx, true) anExecutor := NewExecutor(aComponent.View, aSession) - return anExecutor.NewHandlerSession(ctx) + // ensure Execute(ctx) uses the provided tx (avoid autocommit) + if tx := stateOptions.SqlTx(); tx != nil { + anExecutor.tx = tx + } + return anExecutor.NewHandlerSession(ctx, WithLogger(aSession.Logger())) } func (e *Executor) newHttp() http2.Http { diff --git a/service/executor/handler/locator/handler.go b/service/executor/handler/locator/handler.go index 3758b8abd..6f0c31aeb 100644 --- a/service/executor/handler/locator/handler.go +++ b/service/executor/handler/locator/handler.go @@ -11,6 +11,7 @@ import ( "github.com/viant/datly/view/state" "github.com/viant/datly/view/state/kind" "github.com/viant/datly/view/state/kind/locator" + "reflect" ) type Handler struct { @@ -22,7 +23,7 @@ func (v *Handler) Names() []string { return nil } -func (v *Handler) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (v *Handler) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { resource := 
v.options.Resource if resource == nil { return nil, false, fmt.Errorf("failed to lookup handler resource: %v", name) diff --git a/service/executor/handler/options.go b/service/executor/handler/options.go index 538e130df..e1221318b 100644 --- a/service/executor/handler/options.go +++ b/service/executor/handler/options.go @@ -4,6 +4,7 @@ import ( "embed" "github.com/viant/datly/service/auth" "github.com/viant/datly/view/state" + "github.com/viant/xdatly/handler/logger" ) type options struct { @@ -11,6 +12,7 @@ type options struct { embedFS *embed.FS opts []Option auth *auth.Service + logger logger.Logger } func (o *options) Clone(opts []Option) *options { @@ -37,6 +39,12 @@ func WithTypes(types ...*state.Type) Option { } } +func WithLogger(logger logger.Logger) Option { + return func(o *options) { + o.logger = logger + } +} + func WithAuth(auth *auth.Service) Option { return func(o *options) { o.auth = auth diff --git a/service/executor/service.go b/service/executor/service.go index c89e8fc41..086a40695 100644 --- a/service/executor/service.go +++ b/service/executor/service.go @@ -4,6 +4,11 @@ import ( "context" "database/sql" "fmt" + "reflect" + "strings" + "sync/atomic" + "time" + "github.com/viant/datly/logger" expand2 "github.com/viant/datly/service/executor/expand" vsession "github.com/viant/datly/service/session" @@ -13,8 +18,6 @@ import ( "github.com/viant/sqlx/option" "github.com/viant/xdatly/handler/exec" "github.com/viant/xdatly/handler/response" - "reflect" - "time" ) type ( @@ -31,6 +34,8 @@ type ( dbSource DBSource collections map[string]*batcher.Collection logger *logger.Adapter + inserted int32 + updated int32 } DBOption func(options *DBOptions) @@ -190,6 +195,9 @@ func (e *Executor) handleUpdate(ctx context.Context, sess *dbSession, db *sql.DB options = append(options, db) updated, err := service.Exec(ctx, executable.Data, options...) 
+ if err == nil { + atomic.AddInt32(&sess.updated, int32(updated)) + } e.logMetrics(ctx, executable.Table, "UPDATE", updated, now, err) return err } @@ -212,7 +220,7 @@ func (e *Executor) logMetrics(ctx context.Context, table string, operation strin if err != nil { metric.Error = err.Error() } - value.(*exec.Context).Metrics.Append(&metric) + value.(*exec.Context).AppendMetrics(&metric) } func (e *Executor) handleInsert(ctx context.Context, sess *dbSession, executable *expand2.Executable, db *sql.DB) error { @@ -233,6 +241,9 @@ func (e *Executor) handleInsert(ctx context.Context, sess *dbSession, executable } options = append(options, tx) inserted, _, err = service.Exec(ctx, executable.Data, options...) + if err == nil { + atomic.AddInt32(&sess.inserted, int32(inserted)) + } e.logMetrics(ctx, executable.Table, "INSERT", inserted, started, err) return err } @@ -252,6 +263,23 @@ func (e *Executor) handleInsert(ctx context.Context, sess *dbSession, executable options = append(options, option.BatchSize(batchSize)) options = append(options, e.dbOptions(db, sess)) inserted, _, err = service.Exec(ctx, executable.Data, options...) + if err == nil { + atomic.AddInt32(&sess.inserted, int32(inserted)) + } + isInvalidConnection := err != nil && strings.Contains(err.Error(), "invalid connection") + if isInvalidConnection && atomic.LoadInt32(&sess.inserted) == 0 && atomic.LoadInt32(&sess.updated) == 0 { + var dErr error + db, dErr = sess.dbSource.Db(ctx) + if dErr != nil { + return fmt.Errorf("failed after retry: %w", err) + } + sess.tx.db = db + sess.tx.tx = nil + if _, err = sess.tx.Tx(); err != nil { + return err + } + inserted, _, err = service.Exec(ctx, executable.Data, options...) 
+ } e.logMetrics(ctx, executable.Table, "INSERT", inserted, started, err) return err } diff --git a/service/jobs/service.go b/service/jobs/service.go index c2e8ac530..caaee6a19 100644 --- a/service/jobs/service.go +++ b/service/jobs/service.go @@ -2,6 +2,7 @@ package jobs import ( "context" + "errors" "fmt" "github.com/viant/datly/service/dbms" "github.com/viant/datly/service/reader" @@ -44,7 +45,7 @@ func (s *Service) matchFailedJob(matchKey string) (*async.Job, error) { if candidate.MatchKey == matchKey { var err error if candidate.Error != nil { - err = fmt.Errorf(*candidate.Error) + err = errors.New(*candidate.Error) } else { err = fmt.Errorf("job has status %s", candidate.Status) } diff --git a/service/operator/executor.go b/service/operator/executor.go index 94ba9c1b6..94d7fab94 100644 --- a/service/operator/executor.go +++ b/service/operator/executor.go @@ -3,12 +3,13 @@ package operator import ( "context" "fmt" + "time" + "github.com/viant/datly/repository" "github.com/viant/datly/repository/contract" "github.com/viant/datly/service/executor/handler" "github.com/viant/gmetric/counter" xhandler "github.com/viant/xdatly/handler" - "time" "github.com/viant/datly/service/session" "github.com/viant/datly/view/state/kind/locator" @@ -25,6 +26,7 @@ func (s *Service) execute(ctx context.Context, aComponent *repository.Component, if aComponent.Handler != nil { aSession.SetView(aComponent.View) sessionHandler, err := anExecutor.NewHandlerSession(ctx, + handler.WithLogger(aSession.Logger()), handler.WithTypes(aComponent.Types()...), handler.WithAuth(aSession.Auth())) if err != nil { return nil, err @@ -59,6 +61,8 @@ func (s *Service) execute(ctx context.Context, aComponent *repository.Component, status := contract.StatusSuccess(executorSession.TemplateState) if err := aSession.SetState(ctx, aComponent.Output.Type.Parameters, responseState, aSession.Indirect(true, locator.WithCustom(&status), + locator.WithLogger(aSession.Logger()), + 
locator.WithState(statelet.Template))); err != nil { return nil, fmt.Errorf("failed to set response %w", err) } diff --git a/service/operator/reader.go b/service/operator/reader.go index bca67d847..801638e12 100644 --- a/service/operator/reader.go +++ b/service/operator/reader.go @@ -23,6 +23,10 @@ func (s *Service) runQuery(ctx context.Context, component *repository.Component, defer func() { if r := recover(); r != nil { panicMsg := fmt.Sprintf("Panic occurred: %v, Stack trace: %v", r, string(debug.Stack())) + logger := aSession.Logger() + if logger == nil { + panic(panicMsg) + } aSession.Logger().Errorc(ctx, panicMsg) err = response.NewError(http.StatusInternalServerError, "Internal server error") output = nil @@ -40,7 +44,7 @@ func (s *Service) runQuery(ctx context.Context, component *repository.Component, if err := s.updateJobStatusDone(ctx, component, handlerResponse, setting.SyncFlag, startTime); err != nil { return nil, err } - if output, err = s.finalize(ctx, handlerResponse.Output, handlerResponse.Error); err != nil { + if output, err = s.finalize(ctx, handlerResponse.Output, handlerResponse.Error, aSession); err != nil { aSession.ClearCache(component.Output.Type.Parameters) return s.HandleError(ctx, aSession, component, err) } diff --git a/service/operator/service.go b/service/operator/service.go index bf50ece0b..a5a61ecb1 100644 --- a/service/operator/service.go +++ b/service/operator/service.go @@ -6,11 +6,16 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "reflect" + "time" + "github.com/viant/afs" "github.com/viant/afs/file" "github.com/viant/datly/repository" rasync "github.com/viant/datly/repository/async" "github.com/viant/datly/repository/content" + "github.com/viant/datly/repository/contract" "github.com/viant/datly/service" "github.com/viant/datly/service/reader" "github.com/viant/datly/service/session" @@ -25,13 +30,12 @@ import ( xhandler "github.com/viant/xdatly/handler" "github.com/viant/xdatly/handler/async" 
"github.com/viant/xdatly/handler/exec" + xhttp "github.com/viant/xdatly/handler/http" "github.com/viant/xdatly/handler/logger" "github.com/viant/xdatly/handler/response" hstate "github.com/viant/xdatly/handler/state" + xstate "github.com/viant/xdatly/handler/state" "google.golang.org/api/googleapi" - "net/http" - "reflect" - "time" ) type Service struct { @@ -84,6 +88,7 @@ func (s *Service) HandleError(ctx context.Context, aSession *session.Session, aC func (s *Service) operate(ctx context.Context, aComponent *repository.Component, aSession *session.Session) (interface{}, error) { var err error + ctx, err = s.EnsureContext(ctx, aSession, aComponent) if err != nil { return nil, err @@ -118,13 +123,26 @@ func (s *Service) operate(ctx context.Context, aComponent *repository.Component, } } - return s.finalize(ctx, ret, err) + return s.finalize(ctx, ret, err, aSession) } return nil, response.NewError(500, fmt.Sprintf("unsupported Type %v", aComponent.Service)) } -func (s *Service) finalize(ctx context.Context, ret interface{}, err error) (interface{}, error) { +func (s *Service) finalize(ctx context.Context, ret interface{}, err error, aSession *session.Session) (interface{}, error) { + if injectorFinalizer, ok := ret.(state.InjectorFinalizer); ok { + + lookup := func(ctx context.Context, route xhttp.Route) (xstate.Injector, error) { + aComponent, err := aSession.Registry().Lookup(ctx, contract.NewPath(route.Method, route.URL)) + if err != nil { + return nil, err + } + return aSession.NewSession(aComponent), nil + } + + err = injectorFinalizer.Finalize(ctx, lookup) + return ret, err + } if err != nil { return ret, err } diff --git a/service/reader/handler/handler.go b/service/reader/handler/handler.go index 83d47d11e..6e83335e3 100644 --- a/service/reader/handler/handler.go +++ b/service/reader/handler/handler.go @@ -2,7 +2,8 @@ package handler import ( "context" - goJson "github.com/goccy/go-json" + "encoding/json" + "github.com/viant/datly/gateway/router/status" _ 
"github.com/viant/datly/repository/locator/async" _ "github.com/viant/datly/repository/locator/component" @@ -10,6 +11,9 @@ import ( _ "github.com/viant/datly/repository/locator/output" _ "github.com/viant/datly/service/executor/handler/locator" + "net/http" + "reflect" + reader "github.com/viant/datly/service/reader" "github.com/viant/datly/service/session" "github.com/viant/datly/utils/httputils" @@ -18,8 +22,6 @@ import ( "github.com/viant/datly/view/state/kind/locator" "github.com/viant/structology" "github.com/viant/xdatly/handler/response" - "net/http" - "reflect" ) type ( @@ -65,7 +67,9 @@ func (h *Handler) Handle(ctx context.Context, aView *view.View, aSession *sessio resultState := h.output.NewState() statelet := aSession.State().Lookup(aView) - var locatorOptions []locator.Option + var locatorOptions = []locator.Option{ + locator.WithLogger(aSession.Logger()), + } locatorOptions = append(locatorOptions, locator.WithParameterLookup(func(ctx context.Context, parameter *state.Parameter) (interface{}, bool, error) { return aSession.LookupValue(ctx, parameter, aSession.Indirect(true, locatorOptions...)) }), @@ -135,7 +139,11 @@ func (h *Handler) publishViewSummaryIfNeeded(aView *view.View, ret *Response) { if templateMeta.Kind != view.MetaKindHeader { return } - data, err := goJson.Marshal(ret.Reader.DataSummary) + var data []byte + var err error + if ret.Reader.DataSummary != nil { + data, err = json.Marshal(ret.Reader.DataSummary) + } if err != nil { ret.StatusCode = http.StatusInternalServerError ret.Status.Status = "error" @@ -153,7 +161,7 @@ func (h *Handler) publishMetricsIfNeeded(aSession *reader.Session, ret *Response if info.Executions == nil { continue } - data, err := goJson.Marshal(info) + data, err := json.Marshal(info) if err != nil { continue } diff --git a/service/reader/service.go b/service/reader/service.go index 29f1fea7d..1fda9fc9e 100644 --- a/service/reader/service.go +++ b/service/reader/service.go @@ -4,6 +4,13 @@ import ( "context" 
"database/sql" "fmt" + "reflect" + "strings" + "sync" + "sync/atomic" + "time" + "unsafe" + "github.com/google/uuid" "github.com/viant/datly/service/executor/expand" "github.com/viant/datly/shared" @@ -19,10 +26,6 @@ import ( "github.com/viant/xdatly/handler" "github.com/viant/xdatly/handler/exec" "github.com/viant/xdatly/handler/response" - "reflect" - "sync" - "time" - "unsafe" ) // Service represents reader service @@ -101,7 +104,7 @@ func (s *Service) afterRead(ctx context.Context, aSession *Session, collector *v onFinish(end) if value := ctx.Value(exec.ContextKey); value != nil { if exeCtx := value.(*exec.Context); exeCtx != nil { - exeCtx.Metrics.Append(metrics) + exeCtx.AppendMetrics(metrics) } } } @@ -183,6 +186,7 @@ func (s *Service) readAll(ctx context.Context, session *Session, collector *view } return } + // if onRelationalConcurrency > 1 , then only we call it concurrently concurrencyLimit := make(chan struct{}, onRelationerConcurrency) var onRelationWaitGroup sync.WaitGroup @@ -513,12 +517,25 @@ func (s *Service) queryWithHandler(ctx context.Context, session *Session, aView if session.DryRun { return []*response.SQLExecution{stats}, nil } + + retires := uint32(0) +BEGIN: reader, err := read.New(ctx, db, parametrizedSQL.SQL, collector.NewItem(), options...) + + isInvalidConnection := err != nil && strings.Contains(err.Error(), "invalid connection") + if isInvalidConnection && atomic.AddUint32(&retires, 1) < 3 { + db, err = aView.Connector.DB() + if err != nil { + return nil, fmt.Errorf("failed to connect to db: %w", err) + } + goto BEGIN + } if err != nil { stats.SetError(err) anExec, err := s.HandleSQLError(err, session, aView, parametrizedSQL, stats) return []*response.SQLExecution{anExec}, err } + defer func() { stmt := reader.Stmt() if stmt == nil { @@ -527,7 +544,17 @@ func (s *Service) queryWithHandler(ctx context.Context, session *Session, aView _ = stmt.Close() }() err = reader.QueryAll(ctx, handler, parametrizedSQL.Args...) 
+ + isInvalidConnection = err != nil && strings.Contains(err.Error(), "invalid connection") + if isInvalidConnection && atomic.AddUint32(&retires, 1) < 3 { + db, err = aView.Connector.DB() + if err != nil { + return nil, fmt.Errorf("failed to connect to db: %w", err) + } + goto BEGIN + } end := time.Now() + aView.Logger.ReadingData(end.Sub(begin), parametrizedSQL.SQL, *readData, parametrizedSQL.Args, err) if err != nil { stats.SetError(err) diff --git a/service/reader/sql.go b/service/reader/sql.go index 33d7747ca..256cd26f3 100644 --- a/service/reader/sql.go +++ b/service/reader/sql.go @@ -3,14 +3,15 @@ package reader import ( "context" "fmt" + "strconv" + "strings" + "github.com/viant/datly/service/executor/expand" "github.com/viant/datly/service/reader/metadata" "github.com/viant/datly/shared" "github.com/viant/datly/view" "github.com/viant/datly/view/keywords" "github.com/viant/sqlx/io/read/cache" - "strconv" - "strings" ) const ( @@ -44,19 +45,36 @@ func (b *Builder) Build(ctx context.Context, opts ...BuilderOption) (*cache.Parm options := newBuilderOptions(opts...) 
aView := options.view statelet := options.statelet - batchData := *options.batchData + // guard against nil batchData passed by callers + var batchData view.BatchData + if options.batchData != nil { + batchData = *options.batchData + } relation := options.relation exclude := options.exclude parent := options.parent partitions := options.partition expander := options.expander + + // ensure non-nil statelet to avoid nil deref on Template usage + if statelet == nil { + statelet = view.NewStatelet() + statelet.Init(aView) + } + state, err := aView.Template.EvaluateSource(ctx, statelet.Template, parent, &batchData, expander) if err != nil { return nil, err } + if state == nil { + return nil, fmt.Errorf("failed to evaluate state for view %v, state was nil", aView.Name) + } + if state.Expanded == "" { + return nil, fmt.Errorf("failed to evaluate expanded for view %vm statelet was nil", aView.Name) + } if len(state.Filters) > 0 { - statelet.Filters = append(statelet.Filters, state.Filters...) + statelet.AppendFilters(state.Filters) } if aView.Template.IsActualTemplate() && aView.ShouldTryDiscover() { state.Expanded = metadata.EnrichWithDiscover(state.Expanded, true) @@ -323,7 +341,7 @@ func (b *Builder) updateColumnsIn(params *view.CriteriaParam, batchData *view.Ba params.ColumnsIn = sb.String() } -func (b *Builder) appendOrderBy(sb *strings.Builder, view *view.View, selector *view.Statelet) error { +func (b *Builder) appendOrderBy(sb *strings.Builder, aView *view.View, selector *view.Statelet) error { if selector.OrderBy != "" { fragment := strings.Builder{} items := strings.Split(strings.ReplaceAll(selector.OrderBy, ":", " "), ",") @@ -344,12 +362,23 @@ func (b *Builder) appendOrderBy(sb *strings.Builder, view *view.View, selector * switch strings.ToLower(sortDirection) { case "asc", "desc", "": default: - return fmt.Errorf("invalid sort direction %v for column %v at view %v", sortDirection, column, view.Name) + return fmt.Errorf("invalid sort direction %v for column %v 
at aView %v", sortDirection, column, aView.Name) } - col, ok := view.ColumnByName(column) + col, ok := aView.ColumnByName(column) + if !ok { + + if aView.Selector.Constraints.HasOrderByColumn(column) { + mapped := aView.Selector.Constraints.OrderByColumn[column] + col = &view.Column{ + Name: mapped, + } + ok = true + } + + } if !ok { - return fmt.Errorf("not found column %v at view %v", column, view.Name) + return fmt.Errorf("not found column %v at aView %v", column, aView.Name) } fragment.WriteString(col.Name) if sortDirection != "" { @@ -362,9 +391,9 @@ func (b *Builder) appendOrderBy(sb *strings.Builder, view *view.View, selector * return nil } - if view.Selector.OrderBy != "" { + if aView.Selector.OrderBy != "" { sb.WriteString(orderByFragment) - sb.WriteString(strings.ReplaceAll(view.Selector.OrderBy, ":", " ")) + sb.WriteString(strings.ReplaceAll(aView.Selector.OrderBy, ":", " ")) return nil } diff --git a/service/session/option.go b/service/session/option.go index 0fc7ea6a0..3568b7b35 100644 --- a/service/session/option.go +++ b/service/session/option.go @@ -2,7 +2,9 @@ package session import ( "context" + "database/sql" "embed" + "github.com/viant/datly/repository" "github.com/viant/datly/service/auth" "github.com/viant/datly/view" @@ -32,6 +34,8 @@ type ( scope string embeddedFS *embed.FS auth *auth.Service + preseedCache bool + sqlTx *sql.Tx } Option func(o *Options) @@ -45,6 +49,11 @@ func (o *Options) Registry() *repository.Registry { return o.registry } +// SqlTx returns associated SQL transaction (if any) +func (o *Options) SqlTx() *sql.Tx { + return o.sqlTx +} + func (o *Options) HasInputParameters() bool { if o.locatorOpt == nil { return false @@ -154,6 +163,20 @@ func WithAuth(auth *auth.Service) Option { } } +// WithSQLTx associates an existing SQL transaction with the session +func WithSQLTx(tx *sql.Tx) Option { + return func(s *Options) { + s.sqlTx = tx + } +} + +// WithPreseedCache controls whether NewSession should pre-seed child cache from 
parent (default false) +func WithPreseedCache(flag bool) Option { + return func(s *Options) { + s.preseedCache = flag + } +} + func WithComponent(component *repository.Component) Option { return func(s *Options) { s.component = component @@ -183,3 +206,9 @@ func WithRegistry(registry *repository.Registry) Option { s.registry = registry } } + +func WithLogger(logger logger.Logger) Option { + return func(s *Options) { + s.logger = logger + } +} diff --git a/service/session/selector.go b/service/session/selector.go index 895a953d2..a7febeb8b 100644 --- a/service/session/selector.go +++ b/service/session/selector.go @@ -3,13 +3,15 @@ package session import ( "context" "fmt" + "strconv" + "strings" + "github.com/viant/datly/service/session/criteria" "github.com/viant/datly/view" "github.com/viant/tagly/format/text" "github.com/viant/xdatly/codec" "github.com/viant/xdatly/handler/response" - "strconv" - "strings" + hstate "github.com/viant/xdatly/handler/state" ) func (s *Session) setQuerySelector(ctx context.Context, ns *view.NamespaceView, opts *Options) (err error) { @@ -18,6 +20,12 @@ func (s *Session) setQuerySelector(ctx context.Context, ns *view.NamespaceView, return nil } + selector := s.state.Lookup(ns.View) + + var injected *hstate.NamedQuerySelector + if opts != nil && opts.locatorOpt != nil && opts.locatorOpt.QuerySelectors != nil { + injected = opts.locatorOpt.QuerySelectors.Find(ns.View.Name) + } if err = s.populateFieldQuerySelector(ctx, ns, opts); err != nil { return response.NewParameterError(ns.View.Name, selectorParameters.FieldsParameter.Name, err) } @@ -36,13 +44,63 @@ func (s *Session) setQuerySelector(ctx context.Context, ns *view.NamespaceView, if err = s.populatePageQuerySelector(ctx, ns, opts); err != nil { return response.NewParameterError(ns.View.Name, selectorParameters.PageParameter.Name, err) } - selector := s.state.Lookup(ns.View) + + // Apply injected selector last so it takes precedence over request-derived values, + // but still 
validate against view selector constraints. + if injected != nil { + selector.QuerySelector = injected.QuerySelector + if err := s.applyInjectedQuerySelector(ns, selector, injected); err != nil { + return err + } + } else if selector.Page > 0 && selector.Offset == 0 { + // If selector was pre-set (e.g. from non-query sources) without an explicit page parameter, + // apply Page semantics to compute Offset/Limit. + _ = s.setPageQuerySelector(selector.Page, ns) + } if selector.Limit == 0 && selector.Offset != 0 { return fmt.Errorf("can't use offset without limit - view: %v", ns.View.Name) } return nil } +func (s *Session) applyInjectedQuerySelector(ns *view.NamespaceView, selector *view.Statelet, injected *hstate.NamedQuerySelector) error { + if injected == nil || selector == nil { + return nil + } + if len(injected.Fields) > 0 { + if err := s.setFieldsQuerySelector(injected.Fields, ns); err != nil { + return err + } + } + if injected.Limit != 0 { + if err := s.setLimitQuerySelector(injected.Limit, ns); err != nil { + return err + } + } + if injected.Offset != 0 { + if err := s.setOffsetQuerySelector(injected.Offset, ns); err != nil { + return err + } + } + if injected.OrderBy != "" { + items := strings.Split(injected.OrderBy, ",") + if err := s.setOrderByQuerySelector(items, ns); err != nil { + return err + } + } + if injected.Criteria != "" { + if err := s.setCriteriaQuerySelector(injected.Criteria, ns); err != nil { + return err + } + } + if injected.Page != 0 { + if err := s.setPageQuerySelector(injected.Page, ns); err != nil { + return err + } + } + return nil +} + func (s *Session) setQuerySettings(ctx context.Context, ns *view.NamespaceView, opts *Options) (err error) { selectorParameters := ns.View.Selector if selectorParameters == nil { @@ -154,6 +212,9 @@ func (s *Session) setOrderByQuerySelector(value interface{}, ns *view.NamespaceV continue //position based, not need to validate } + if ns.View.Selector.Constraints.HasOrderByColumn(column) { + continue + } 
_, ok := ns.View.ColumnByName(column) if !ok { return fmt.Errorf("not found column %v at view %v", items, ns.View.Name) @@ -201,7 +262,10 @@ func (s *Session) setLimitQuerySelector(value interface{}, ns *view.NamespaceVie return fmt.Errorf("can't use Limit on view %v", ns.View.Name) } selector := s.state.Lookup(ns.View) - limit := value.(int) + limit, err := toInt(value) + if err != nil { + return fmt.Errorf("invalid limit value: %v", err) + } if limit <= ns.View.Selector.Limit || ns.View.Selector.Limit == 0 { selector.Limit = limit } @@ -223,7 +287,19 @@ func (s *Session) setFieldsQuerySelector(value interface{}, ns *view.NamespaceVi return fmt.Errorf("can't use projection on view %v", ns.View.Name) } selector := s.state.Lookup(ns.View) - fields := value.([]string) + var fields []string + switch v := value.(type) { + case []string: + fields = v + case []interface{}: + for _, elem := range v { + text, ok := elem.(string) + if !ok { + continue + } + fields = append(fields, text) + } + } for _, field := range fields { fieldName := ns.View.CaseFormat.Format(field, text.CaseFormatUpperCamel) if err = canUseColumn(ns.View, fieldName); err != nil { @@ -270,3 +346,20 @@ func canUseColumn(aView *view.View, columnName string) error { } return nil } + +func toInt(v interface{}) (int, error) { + switch val := v.(type) { + case int: + return val, nil + case int32: + return int(val), nil + case int64: + return int(val), nil + case float64: + return int(val), nil + case float32: + return int(val), nil + default: + return 0, fmt.Errorf("unsupported type: %T", v) + } +} diff --git a/service/session/selector_injector_test.go b/service/session/selector_injector_test.go new file mode 100644 index 000000000..7f8275fb6 --- /dev/null +++ b/service/session/selector_injector_test.go @@ -0,0 +1,89 @@ +package session + +import ( + "context" + "net/http" + "reflect" + "testing" + + "github.com/viant/datly/repository" + "github.com/viant/datly/view" + vstate 
"github.com/viant/datly/view/state" + hstate "github.com/viant/xdatly/handler/state" +) + +func TestSessionBind_QuerySelectorOverride_PageComputesOffset(t *testing.T) { + ctx := context.Background() + + resource := view.NewResource(nil) + trueValue := true + aView := &view.View{ + Name: "v", + Mode: view.ModeQuery, + Selector: func() *view.Config { + cfg := view.QueryStateParameters.Clone() + cfg.Limit = 5 + cfg.Constraints = &view.Constraints{ + Criteria: true, + OrderBy: true, + Limit: true, + Offset: true, + Projection: true, + Page: &trueValue, + } + return cfg + }(), + } + aView.SetResource(resource) + aView.Template = &view.Template{Schema: vstate.NewSchema(reflect.TypeOf(struct{ Dummy int }{}))} + if err := aView.Template.Init(ctx, resource, aView); err != nil { + t.Fatalf("failed to init template: %v", err) + } + if err := aView.Selector.Init(ctx, resource, aView); err != nil { + t.Fatalf("failed to init selector: %v", err) + } + + component := &repository.Component{View: aView} + outputType, err := vstate.NewType( + vstate.WithSchema(vstate.NewSchema(reflect.TypeOf(struct{ X int }{}))), + vstate.WithResource(aView.Resource()), + ) + if err != nil { + t.Fatalf("failed to build component output type: %v", err) + } + component.Output.Type = *outputType + + sess := New(aView, WithComponent(component)) + var dest struct{} + + // request supplies different selector values; injected selector should take precedence + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/?_page=1&_limit=1", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + + err = sess.Bind(ctx, &dest, hstate.WithQuerySelector(&hstate.NamedQuerySelector{ + Name: "v", + QuerySelector: hstate.QuerySelector{ + Page: 2, + }, + }), hstate.WithHttpRequest(req)) + if err != nil { + t.Fatalf("Bind() error: %v", err) + } + + if err := sess.SetViewState(ctx, aView); err != nil { + t.Fatalf("SetViewState() error: %v", err) + } + + selector := sess.State().Lookup(aView) + if 
selector.Page != 2 { + t.Fatalf("expected Page=2, got %d", selector.Page) + } + if selector.Limit != 5 { + t.Fatalf("expected Limit=5, got %d", selector.Limit) + } + if selector.Offset != 5 { + t.Fatalf("expected Offset=5, got %d", selector.Offset) + } +} diff --git a/service/session/state.go b/service/session/state.go index 62a75f07f..5779250fc 100644 --- a/service/session/state.go +++ b/service/session/state.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "github.com/viant/datly/internal/converter" + "github.com/viant/datly/repository" "github.com/viant/datly/service/auth" "github.com/viant/datly/utils/types" "github.com/viant/datly/view" @@ -42,6 +43,42 @@ type ( } ) +func (s *Session) NewSession(component *repository.Component) *Session { + ret := *s + // set component and view on the child session (do not mutate receiver) + ret.component = component + ret.Options.component = component + ret.view = component.View + if ret.locatorOpt != nil { + if _, ok := ret.locatorOpt.Views[component.View.Name]; !ok { + ret.locatorOpt.Views.Register(component.View) + } + } + + // create a fresh cache and optionally pre-populate from parent cache values + parent := s.cache + ret.cache = newCache() + if ret.Options.preseedCache && parent != nil { + parent.RWMutex.RLock() + for k, v := range parent.values { + ret.cache.values[k] = v + } + parent.RWMutex.RUnlock() + } + + // reset predicates (filters) on the child session state + if ret.Options.state != nil { + ret.Options.state.RWMutex.Lock() + for _, st := range ret.Options.state.Views { + if st != nil { + st.Filters = nil + } + } + ret.Options.state.RWMutex.Unlock() + } + return &ret +} + func (s *Session) SetView(view *view.View) { s.view = view } @@ -168,6 +205,7 @@ func (s *Session) viewLookupOptions(aView *view.View, parameters state.NamedPara if !opts.HasInputParameters() { result = append(result, locator.WithInputParameters(parameters)) } + result = append(result, locator.WithLogger(s.logger)) result = 
append(result, locator.WithReadInto(s.ReadInto)) viewState := s.state.Lookup(aView) result = append(result, locator.WithState(viewState.Template)) @@ -202,7 +240,7 @@ func (s *Session) setTemplateState(ctx context.Context, aView *view.View, opts * aState := s.state.Lookup(aView) if template := aView.Template; template != nil { stateType := template.StateType() - if stateType.IsDefined() { + if stateType != nil && stateType.IsDefined() { templateState := aState.Template templateState.EnsureMarker() err := s.SetState(ctx, template.Parameters, templateState, opts) @@ -247,6 +285,52 @@ func (s *Session) populateParameterInBackground(ctx context.Context, parameter * } } +// The function below causes SIGBUS when template parameters are rebound. +//E.g. a predicate builder velty expression is located in an embedded SQL, outside main DQL +//func (s *Session) populateParameter(ctx context.Context, parameter *state.Parameter, aState *structology.State, options *Options) error { +// value, has, err := s.LookupValue(ctx, parameter, options) +// if err != nil { +// return err +// } +// if !has { +// if parameter.IsRequired() { +// return fmt.Errorf("parameter %v is required", parameter.Name) +// } +// return nil +// } +// +// parameterSelector := parameter.Selector() +// if options.indirectState || parameterSelector == nil { //p +// parameterSelector, err = aState.Selector(parameter.Name) +// if parameterSelector == nil { +// switch parameter.In.Kind { +// case state.KindConst: +// return nil +// } +// } +// if err != nil { +// return err +// } +// } +// +// if value, err = s.ensureValidValue(value, parameter, parameterSelector, options); err != nil { +// return err +// } +// err = parameterSelector.SetValue(aState.Pointer(), value) +// +// //ensure last written can be shared +// if err == nil { +// +// switch parameterSelector.Type().Kind() { +// case reflect.Ptr: +// if parameter.Schema.Type() == parameterSelector.Type() { +// s.cache.put(parameter, 
parameterSelector.Value(aState.Pointer())) +// } +// } +// } +// return err +//} + func (s *Session) populateParameter(ctx context.Context, parameter *state.Parameter, aState *structology.State, options *Options) error { value, has, err := s.LookupValue(ctx, parameter, options) if err != nil { @@ -258,29 +342,25 @@ func (s *Session) populateParameter(ctx context.Context, parameter *state.Parame } return nil } - parameterSelector := parameter.Selector() - if options.indirectState || parameterSelector == nil { //p - parameterSelector, err = aState.Selector(parameter.Name) - if parameterSelector == nil && parameter.In.Kind == state.KindConst { // TODO do we really need it? - return nil - } - if err != nil { - return err - } + + // Resolve selector strictly from the state's layout + // Treat "not found" as a no-op (skip), since this view doesn't declare that parameter. + parameterSelector, err := aState.Selector(parameter.Name) + if err != nil || parameterSelector == nil { + return nil } + if value, err = s.ensureValidValue(value, parameter, parameterSelector, options); err != nil { return err } - err = parameterSelector.SetValue(aState.Pointer(), value) + if err = parameterSelector.SetValue(aState.Pointer(), value); err != nil { + return err + } - //ensure last written can be shared - if err == nil { - switch parameterSelector.Type().Kind() { - case reflect.Ptr: - s.cache.put(parameter, parameterSelector.Value(aState.Pointer())) - } + if parameterSelector.Type().Kind() == reflect.Ptr { + s.cache.put(parameter, parameterSelector.Value(aState.Pointer())) } - return err + return nil } func (s *Session) canRead(ctx context.Context, parameter *state.Parameter, opts *Options) (bool, error) { @@ -346,16 +426,22 @@ func (s *Session) ensureValidValue(value interface{}, parameter *state.Parameter if valueType.Elem().Kind() == reflect.Struct && parameter.Schema.Type().Kind() == reflect.Slice { if parameter.Schema.CompType() == valueType { sliceValuePtr := 
reflect.New(parameterType) + + if isNil(value) { + empty := reflect.MakeSlice(parameterType, 0, 0) + sliceValuePtr.Elem().Set(empty) + return sliceValuePtr.Interface(), nil // []T{} + } + sliceValue := reflect.MakeSlice(parameterType, 1, 1) sliceValuePtr.Elem().Set(sliceValue) sliceValue.Index(0).Set(reflect.ValueOf(value)) - return sliceValuePtr.Interface(), nil + return sliceValuePtr.Interface(), nil // []T{value}` } } case reflect.Slice: - ptr := xunsafe.AsPointer(value) - slice := parameter.Schema.Slice() - sliceLen := slice.Len(ptr) + rSlice := reflect.ValueOf(value) + sliceLen := rSlice.Len() if errorMessage := validateSliceParameter(parameter, sliceLen); errorMessage != "" { return nil, errors.New(errorMessage) } @@ -366,11 +452,45 @@ func (s *Session) ensureValidValue(value interface{}, parameter *state.Parameter default: switch sliceLen { case 0: - value = reflect.New(parameter.OutputType().Elem()).Elem().Interface() + switch outputType.Kind() { + case reflect.Ptr: + value = reflect.New(outputType.Elem()).Elem().Interface() + case reflect.Struct: + value = reflect.New(outputType).Elem().Interface() + default: + value = reflect.New(outputType).Elem().Interface() + } valueType = reflect.TypeOf(value) case 1: - value = slice.ValuePointerAt(ptr, 0) - valueType = reflect.TypeOf(value) + elem := rSlice.Index(0) + rawType := elem.Type() + if rawType.Kind() == reflect.Ptr { + rawType = rawType.Elem() + } + if rawType.Kind() == reflect.Interface { + rawType = rawType.Elem() + } + + if elem.Kind() == reflect.Interface && !elem.IsNil() { + elem = elem.Elem() + } + if rawType.Kind() != reflect.Struct { + value = elem.Interface() + valueType = reflect.TypeOf(value) + break + } + if elem.Kind() == reflect.Ptr { + value = elem.Interface() + valueType = elem.Type() + break + } + if elem.CanAddr() { + value = elem.Addr().Interface() + valueType = elem.Addr().Type() + break + } + value = elem.Interface() + valueType = elem.Type() default: return nil, fmt.Errorf("parameter 
%v return more than one value, len: %v rows ", parameter.Name, sliceLen) } @@ -388,53 +508,55 @@ func (s *Session) ensureValidValue(value interface{}, parameter *state.Parameter } if parameter.Schema.IsStruct() && !(valueType == selector.Type() || valueType.ConvertibleTo(selector.Type()) || valueType.AssignableTo(selector.Type())) { - - rawSelectorType := selector.Type() - isSelectorPtr := false - if rawSelectorType.Kind() == reflect.Ptr { - rawSelectorType = rawSelectorType.Elem() - isSelectorPtr = true + destType := selector.Type() + rawDestType := destType + destIsPtr := false + if rawDestType.Kind() == reflect.Ptr { + rawDestType = rawDestType.Elem() + destIsPtr = true } - isValuePtr := false - rawValueType := valueType - if rawValueType.Kind() == reflect.Ptr { - rawValueType = valueType.Elem() - isValuePtr = true + + rawSrcType := valueType + srcIsPtr := false + if rawSrcType.Kind() == reflect.Ptr { + rawSrcType = rawSrcType.Elem() + srcIsPtr = true } - if rawSelectorType.Kind() == reflect.Struct && isSelectorPtr { - if rawValueType.ConvertibleTo(rawSelectorType) { - ptrValue := reflect.ValueOf(value) - if isValuePtr && ptrValue.IsNil() { + if rawDestType.Kind() == reflect.Struct && rawSrcType.Kind() == reflect.Struct && rawSrcType.ConvertibleTo(rawDestType) { + srcValue := reflect.ValueOf(value) + if srcIsPtr { + if srcValue.IsNil() { return nil, nil } - var destValue reflect.Value - if isValuePtr { - destValue = ptrValue.Elem().Convert(rawSelectorType) - } else { - destValue = ptrValue.Convert(rawSelectorType) - } - if isSelectorPtr { - destPtrType := reflect.New(valueType) - destPtrType.Elem().Set(destValue) - return destPtrType.Interface(), nil - } else { - return destValue.Interface(), nil - } + srcValue = srcValue.Elem() } + converted := srcValue.Convert(rawDestType) + if destIsPtr { + out := reflect.New(rawDestType) + out.Elem().Set(converted) + return out.Interface(), nil + } + return converted.Interface(), nil } if options.shallReportNotAssignable() { 
- //if !ensureAssignable(parameter.Name, selector.Type(), valueType) { - fmt.Printf("parameter %v is not directly assignable from %s:(%s)\nsrc:%s \ndst:%s\n", parameter.Name, parameter.In.Kind, parameter.In.Name, valueType.String(), selector.Type().String()) - //} + fmt.Printf("parameter %v is not directly assignable from %s:(%s)\nsrc:%s \ndst:%s\n", parameter.Name, parameter.In.Kind, parameter.In.Name, valueType.String(), destType.String()) } - reflectValue := reflect.New(valueType) //TODO replace with fast xreflect copy - valuePtr := reflectValue.Interface() + var target reflect.Value + if destIsPtr { + target = reflect.New(rawDestType) // *T where destType is *T + } else { + target = reflect.New(destType) // *T where destType is T + } if data, err := json.Marshal(value); err == nil { - if err = json.Unmarshal(data, valuePtr); err == nil { - value = reflectValue.Elem().Interface() + if err = json.Unmarshal(data, target.Interface()); err == nil { + if destIsPtr { + value = target.Interface() + } else { + value = target.Elem().Interface() + } } } } @@ -519,8 +641,8 @@ func (s *Session) lookupFirstValue(ctx context.Context, parameters []*state.Para } func (s *Session) LookupValue(ctx context.Context, parameter *state.Parameter, opts *Options) (value interface{}, has bool, err error) { - - if value, has, err = s.lookupValue(ctx, parameter, opts); err != nil { + value, has, err = s.lookupValue(ctx, parameter, opts) + if err != nil { err = response.NewParameterError("", parameter.Name, err, response.WithObject(value), response.WithErrorStatusCode(parameter.ErrorStatusCode)) } return value, has, err @@ -573,7 +695,8 @@ func (s *Session) lookupValue(ctx context.Context, parameter *state.Parameter, o if err != nil { return nil, false, fmt.Errorf("failed to locate parameter: %v, %w", parameter.Name, err) } - if value, has, err = parameterLocator.Value(ctx, parameter.In.Name); err != nil { + + if value, has, err = parameterLocator.Value(ctx, parameter.OutputType(), 
parameter.In.Name); err != nil { return nil, false, err } if parameter.In.Kind == state.KindConst && !has { //if parameter is const and has no value, use default value @@ -588,7 +711,7 @@ func (s *Session) lookupValue(ctx context.Context, parameter *state.Parameter, o if err != nil { return nil, false, fmt.Errorf("failed to locate parameter: %v, %w", baseParameter.Name, err) } - if value, has, err = parameterLocator.Value(ctx, baseParameter.In.Name); err != nil { + if value, has, err = parameterLocator.Value(ctx, baseParameter.OutputType(), baseParameter.In.Name); err != nil { return nil, false, err } } @@ -610,6 +733,13 @@ func (s *Session) adjustAndCache(ctx context.Context, parameter *state.Parameter return nil, false, err } if parameter.Output != nil { + // Defensive: ensure codec is initialized before Transform. + if !parameter.Output.Initialized() { + // Initialize using session resource and current parameter input type. + if initErr := parameter.Output.Init(s.resource, parameter.Schema.Type()); initErr != nil { + return nil, false, initErr + } + } transformed, err := parameter.Output.Transform(ctx, value, opts.codecOptions...) if err != nil { return nil, false, fmt.Errorf("failed to transform %s with %s: %v, %w", parameter.Name, parameter.Output.Name, value, err) @@ -700,7 +830,42 @@ func New(aView *view.View, opts ...Option) *Session { return ret } -func (s *Session) LoadState(parameters state.Parameters, aState interface{}) error { +type loadStateOptions struct { + skipKind map[state.Kind]bool + hasSkipKind bool + useHasMarker bool + fallbackOnValue bool +} + +type LoadStateOption func(o *loadStateOptions) + +func WithHasMarker() LoadStateOption { + return func(o *loadStateOptions) { + o.useHasMarker = true + } +} + +// WithValuePresenceFallback treats non-zero values as present when no Has marker is available. +// This is opt-in to avoid changing behavior for existing inputs that intentionally omit markers. 
+func WithValuePresenceFallback() LoadStateOption { + return func(o *loadStateOptions) { + o.fallbackOnValue = true + } +} +func WithLoadStateSkipKind(kinds ...state.Kind) LoadStateOption { + return func(o *loadStateOptions) { + for _, kind := range kinds { + o.skipKind[kind] = true + } + } +} + +func (s *Session) LoadState(parameters state.Parameters, aState interface{}, opts ...LoadStateOption) error { + options := &loadStateOptions{skipKind: map[state.Kind]bool{}} + for _, opt := range opts { + opt(options) + } + options.hasSkipKind = len(options.skipKind) > 0 rType := reflect.TypeOf(aState) sType := structology.NewStateType(rType, structology.WithCustomizedNames(func(name string, tag reflect.StructTag) []string { stateTag, _ := tags.ParseStateTags(tag, nil) @@ -711,29 +876,46 @@ func (s *Session) LoadState(parameters state.Parameters, aState interface{}) err })) inputState := sType.WithValue(aState) ptr := xunsafe.AsPointer(aState) + // Use presence markers only if enabled and supported by the input state + hasMarker := options.useHasMarker && inputState.HasMarker() for _, parameter := range parameters { + if parameter.Scope != "" { continue } + + if options.hasSkipKind && options.skipKind[parameter.In.Kind] { + continue + } + + // Only warm cache for cacheable parameters; LookupValue only reads cache when cacheable + if !parameter.IsCacheable() { + continue + } selector, _ := inputState.Selector(parameter.Name) if selector == nil { continue } - if !selector.Has(ptr) { + // Only use selector.Has when input supports presence markers + if hasMarker && !selector.Has(ptr) { continue } + value := selector.Value(ptr) + if !hasMarker && options.fallbackOnValue && isZeroValue(value) { + continue + } switch parameter.In.Kind { case state.KindView, state.KindParam, state.KindState: if value == nil { - return nil + continue } rType := parameter.OutputType() if rType.Kind() == reflect.Ptr { ptr := (*unsafe.Pointer)(xunsafe.AsPointer(value)) if ptr == nil || *ptr == nil { - 
return nil + continue } } } @@ -743,10 +925,28 @@ func (s *Session) LoadState(parameters state.Parameters, aState interface{}) err return nil } +func isZeroValue(value interface{}) bool { + if value == nil { + return true + } + v := reflect.ValueOf(value) + for v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + switch v.Kind() { + case reflect.Slice, reflect.Map, reflect.Array: + return v.Len() == 0 + } + return v.IsZero() +} + func (s *Session) handleParameterError(parameter *state.Parameter, err error, errors *response.Errors) { if parameter.ErrorMessage != "" && err != nil { msg := strings.ReplaceAll(parameter.ErrorMessage, "${error}", err.Error()) - err = fmt.Errorf(msg) + err = fmt.Errorf("%s", msg) } if pErr, ok := err.(*response.Error); ok { pErr.Code = parameter.ErrorStatusCode diff --git a/service/session/state_test.go b/service/session/state_test.go new file mode 100644 index 000000000..497d6aca6 --- /dev/null +++ b/service/session/state_test.go @@ -0,0 +1,291 @@ +package session + +import ( + "reflect" + "testing" + + "github.com/viant/datly/view/state" + "github.com/viant/structology" +) + +func TestSessionEnsureValidValue_Transitions(t *testing.T) { + type T struct { + A *int + B *int + } + + inlineStructSwapped := reflect.StructOf([]reflect.StructField{ + // Deliberately swap field order vs T to ensure the types are not convertible. 
+ {Name: "B", Type: reflect.TypeOf((*int)(nil))}, + {Name: "A", Type: reflect.TypeOf((*int)(nil))}, + }) + inlinePtrType := reflect.PtrTo(inlineStructSwapped) + + newSelector := func(t *testing.T, paramType reflect.Type) *structology.Selector { + t.Helper() + stateStruct := reflect.StructOf([]reflect.StructField{ + {Name: "Param", Type: paramType}, + }) + stateType := structology.NewStateType(stateStruct) + selector := stateType.Lookup("Param") + if selector == nil { + t.Fatalf("failed to lookup selector Param") + } + return selector + } + + intPtrType := reflect.TypeOf((*int)(nil)) + + ttPtrType := reflect.TypeOf((*T)(nil)) + sliceOfTTPtrType := reflect.SliceOf(ttPtrType) + ptrToSliceOfTTPtrType := reflect.PtrTo(sliceOfTTPtrType) + intType := reflect.TypeOf(int(0)) + sliceOfIntType := reflect.SliceOf(intType) + ttType := reflect.TypeOf(T{}) + + boolPtr := func(v bool) *bool { return &v } + + cases := []struct { + name string + schemaType reflect.Type + selectorType reflect.Type + required *bool + value interface{} + wantType reflect.Type + wantErr bool + check func(t *testing.T, got interface{}) + }{ + { + name: "nil-value_ptr-schema_returns-typed-nil", + schemaType: intPtrType, + selectorType: intPtrType, + value: nil, + wantType: intPtrType, + check: func(t *testing.T, got interface{}) { + t.Helper() + if !reflect.ValueOf(got).IsNil() { + t.Fatalf("expected nil pointer, got %v", got) + } + }, + }, + { + name: "nil-value_slice-schema_returns-nil-slice", + schemaType: sliceOfIntType, + selectorType: sliceOfIntType, + value: nil, + wantType: sliceOfIntType, + check: func(t *testing.T, got interface{}) { + t.Helper() + if !reflect.ValueOf(got).IsNil() { + t.Fatalf("expected nil slice, got %v", got) + } + }, + }, + { + name: "ptr-struct_to_ptr-to-slice-wraps-single", + schemaType: sliceOfTTPtrType, + selectorType: ptrToSliceOfTTPtrType, + value: func() interface{} { + a := 10 + b := 20 + return &T{A: &a, B: &b} + }(), + wantType: ptrToSliceOfTTPtrType, + check: 
func(t *testing.T, got interface{}) { + t.Helper() + gotSlicePtr := reflect.ValueOf(got) + if gotSlicePtr.IsNil() { + t.Fatalf("expected non-nil pointer to slice") + } + gotSlice := gotSlicePtr.Elem() + if gotSlice.Len() != 1 { + t.Fatalf("expected len=1, got %d", gotSlice.Len()) + } + if gotSlice.Index(0).IsNil() { + t.Fatalf("expected element 0 to be non-nil") + } + }, + }, + { + name: "ptr-struct-nil_to_ptr-to-slice-wraps-empty", + schemaType: sliceOfTTPtrType, + selectorType: ptrToSliceOfTTPtrType, + value: (*T)(nil), + wantType: ptrToSliceOfTTPtrType, + check: func(t *testing.T, got interface{}) { + t.Helper() + gotSlicePtr := reflect.ValueOf(got) + if gotSlicePtr.IsNil() { + t.Fatalf("expected non-nil pointer to slice") + } + gotSlice := gotSlicePtr.Elem() + if gotSlice.Len() != 0 { + t.Fatalf("expected len=0, got %d", gotSlice.Len()) + } + }, + }, + { + name: "slice-to-scalar_len0_required_errors", + schemaType: ttPtrType, + selectorType: ttPtrType, + required: boolPtr(true), + value: []*T{}, + wantErr: true, + }, + { + name: "slice-to-scalar_len0_not-required_returns-zero", + schemaType: ttPtrType, + selectorType: ttPtrType, + value: []*T{}, + wantType: ttPtrType, + check: func(t *testing.T, got interface{}) { + t.Helper() + if reflect.ValueOf(got).IsNil() { + t.Fatalf("expected non-nil *T") + } + }, + }, + { + name: "slice-of-int_len1_to-int", + schemaType: intType, + selectorType: intType, + value: []int{7}, + wantType: intType, + check: func(t *testing.T, got interface{}) { + t.Helper() + if got.(int) != 7 { + t.Fatalf("expected 7, got %v", got) + } + }, + }, + { + name: "slice-of-int_len2_to-int_errors", + schemaType: intType, + selectorType: intType, + value: []int{1, 2}, + wantErr: true, + }, + { + name: "ptr-required_nil_errors", + schemaType: ttPtrType, + selectorType: ttPtrType, + required: boolPtr(true), + value: (*T)(nil), + wantErr: true, + }, + { + name: "ptr-value_to-struct-selector_derefs", + schemaType: ttType, + selectorType: ttType, + 
value: func() interface{} { + a := 3 + b := 4 + return &T{A: &a, B: &b} + }(), + wantType: ttType, + check: func(t *testing.T, got interface{}) { + t.Helper() + gotT := got.(T) + if gotT.A == nil || gotT.B == nil { + t.Fatalf("expected non-nil fields") + } + if *gotT.A != 3 || *gotT.B != 4 { + t.Fatalf("unexpected values: %+v", gotT) + } + }, + }, + { + name: "struct-value_to-ptr-selector_allocates", + schemaType: ttPtrType, + selectorType: ttPtrType, + value: func() interface{} { + a := 5 + b := 6 + return T{A: &a, B: &b} + }(), + wantType: ttPtrType, + check: func(t *testing.T, got interface{}) { + t.Helper() + gotPtr := got.(*T) + if gotPtr == nil || gotPtr.A == nil || gotPtr.B == nil { + t.Fatalf("expected non-nil *T with non-nil fields") + } + if *gotPtr.A != 5 || *gotPtr.B != 6 { + t.Fatalf("unexpected values: %+v", *gotPtr) + } + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + parameter := &state.Parameter{ + Name: "Param", + In: state.NewState("Param"), + Schema: state.NewSchema(tc.schemaType), + Required: tc.required, + } + + selector := newSelector(t, tc.selectorType) + sess := &Session{} + opts := NewOptions(WithReportNotAssignable(false)) + + got, err := sess.ensureValidValue(tc.value, parameter, selector, opts) + if (err != nil) != tc.wantErr { + t.Fatalf("error=%v, wantErr=%v", err, tc.wantErr) + } + if tc.wantErr { + return + } + if tc.wantType != nil && reflect.TypeOf(got) != tc.wantType { + t.Fatalf("expected %v, got %T", tc.wantType, got) + } + if tc.check != nil { + tc.check(t, got) + } + }) + } + + t.Run("slice-of-named-ptr_to-inline-ptr_allocates-and-copies_details", func(t *testing.T) { + a := 1 + b := 2 + original := &T{A: &a, B: &b} + input := []*T{original} + + parameter := &state.Parameter{ + Name: "Param", + In: state.NewState("Param"), + Schema: state.NewSchema(inlinePtrType), + } + selector := newSelector(t, inlinePtrType) + sess := &Session{} + opts := NewOptions(WithReportNotAssignable(false)) + + 
got, err := sess.ensureValidValue(input, parameter, selector, opts) + if err != nil { + t.Fatalf("ensureValidValue error: %v", err) + } + if reflect.TypeOf(got) != inlinePtrType { + t.Fatalf("expected %v, got %T", inlinePtrType, got) + } + + gotPtr := reflect.ValueOf(got).Pointer() + origPtr := reflect.ValueOf(original).Pointer() + if gotPtr == origPtr { + t.Fatalf("expected ensureValidValue to allocate/copy into %v; got aliases original *T pointer %x", inlinePtrType, gotPtr) + } + + gotValue := reflect.ValueOf(got).Elem() + gotA := gotValue.FieldByName("A") + gotB := gotValue.FieldByName("B") + if gotA.IsNil() || gotB.IsNil() { + t.Fatalf("expected A and B to be non-nil") + } + if gotA.Elem().Int() != int64(*original.A) { + t.Fatalf("expected A=%d, got %d", *original.A, gotA.Elem().Int()) + } + if gotB.Elem().Int() != int64(*original.B) { + t.Fatalf("expected B=%d, got %d", *original.B, gotB.Elem().Int()) + } + }) +} diff --git a/service/session/stater.go b/service/session/stater.go index 2c75bcbe9..5b0b0d6ab 100644 --- a/service/session/stater.go +++ b/service/session/stater.go @@ -2,11 +2,19 @@ package session import ( "context" + "fmt" + "net/http" + "reflect" + "runtime/debug" + + "embed" + "github.com/viant/datly/utils/types" + "github.com/viant/datly/view" "github.com/viant/datly/view/state" "github.com/viant/datly/view/state/kind/locator" + "github.com/viant/xdatly/handler/response" hstate "github.com/viant/xdatly/handler/state" - "reflect" ) func (s *Session) ValuesOf(ctx context.Context, any interface{}) (map[string]interface{}, error) { @@ -35,15 +43,48 @@ func (s *Session) Into(ctx context.Context, dest interface{}, opts ...hstate.Opt } func (s *Session) Bind(ctx context.Context, dest interface{}, opts ...hstate.Option) (err error) { + defer func() { + if r := recover(); r != nil { + panicMsg := fmt.Sprintf("Panic occurred: %v, Stack trace: %v", r, string(debug.Stack())) + logger := s.Logger() + if logger == nil { + panic(panicMsg) + } + 
s.Logger().Errorc(ctx, panicMsg) + err = response.NewError(http.StatusInternalServerError, "Internal server error") + } + }() + destType := reflect.TypeOf(dest) sType := types.EnsureStruct(destType) stateType, ok := s.Types.Lookup(sType) - if !ok { - if stateType, err = state.NewType( - state.WithSchema(state.NewSchema(destType)), - state.WithResource(s.resource), - ); err != nil { - return err + + var embedFs *embed.FS + if embedder, ok := dest.(state.Embedder); ok { + embedFs = embedder.EmbedFS() + } + + if !ok && s.component != nil { + + if s.component.Input.Type.Type() != nil { + if destType == s.component.Input.Type.Type().Type() { + stateType = &s.component.Input.Type + } + } + if s.component.Output.Type.Type() != nil { + if destType == s.component.Output.Type.Type().Type() { + stateType = &s.component.Output.Type + } + } + + if stateType == nil { + if stateType, err = state.NewType( + state.WithSchema(state.NewSchema(destType)), + state.WithResource(s.resource), + state.WithFS(embedFs), + ); err != nil { + return err + } } s.Types.Put(stateType) } @@ -52,9 +93,11 @@ func (s *Session) Bind(ctx context.Context, dest interface{}, opts ...hstate.Opt } hOptions := hstate.NewOptions(opts...) - aState := stateType.Type().WithValue(dest) - var stateOptions []locator.Option + aState := stateType.Type().WithValue(dest) + var stateOptions = []locator.Option{ + locator.WithLogger(s.logger), + } var locatorsToRemove = []state.Kind{state.KindComponent} if hOptions.Constants() != nil { stateOptions = append(stateOptions, locator.WithConstants(hOptions.Constants())) @@ -76,50 +119,140 @@ func (s *Session) Bind(ctx context.Context, dest interface{}, opts ...hstate.Opt locatorsToRemove = append(locatorsToRemove, httpKinds...) 
} if hOptions.Query() != nil { - stateOptions = append(stateOptions, locator.WithQuery(hOptions.Query())) + queryOpt := locator.WithQuery(hOptions.Query()) + stateOptions = append(stateOptions, queryOpt) + s.locatorOptions = append(s.locatorOptions, queryOpt) locatorsToRemove = append(locatorsToRemove, httpKinds...) } if len(hOptions.PathParameters()) > 0 { - stateOptions = append(stateOptions, locator.WithPathParameters(hOptions.PathParameters())) + pathOpt := locator.WithPathParameters(hOptions.PathParameters()) + stateOptions = append(stateOptions, pathOpt) + s.locatorOptions = append(s.locatorOptions, pathOpt) locatorsToRemove = append(locatorsToRemove, httpKinds...) } if hOptions.HttpRequest() != nil { - stateOptions = append(stateOptions, locator.WithRequest(hOptions.HttpRequest())) + requestOpt := locator.WithRequest(hOptions.HttpRequest()) + stateOptions = append(stateOptions, requestOpt) + s.locatorOptions = append(s.locatorOptions, requestOpt) locatorsToRemove = append(locatorsToRemove, httpKinds...) } + if selectors := hOptions.QuerySelectors(); len(selectors) > 0 { + selectorOpt := locator.WithQuerySelectors(selectors) + stateOptions = append(stateOptions, selectorOpt) + s.locatorOptions = append(s.locatorOptions, selectorOpt) + } + // Keep parsed locator options in sync with any dynamic additions made via injector.Bind. + if len(s.locatorOptions) > 0 { + s.locatorOpt = locator.NewOptions(s.locatorOptions) + s.kindLocator = locator.NewKindsLocator(nil, s.locatorOptions...) + } s.kindLocator.RemoveLocators(locatorsToRemove...) if s.view != nil { viewOptions := s.ViewOptions(s.view, WithLocatorOptions()) stateOptions = append(viewOptions.kindLocator.Options(), stateOptions...) 
} - if s.component != nil && s.component.Contract.Output.Type.Type().Type() == destType { - return s.handleComponentpOutputType(ctx, dest, stateOptions) + if err = s.handleInputState(ctx, hOptions, embedFs); err != nil { + return err + } + + if s.component != nil { + componentOutputType := types.EnsureStruct(s.component.Contract.Output.Type.Type().Type()) + if componentOutputType == types.EnsureStruct(destType) { + return s.handleComponentOutputType(ctx, dest, stateOptions) + } } options := s.Indirect(true, stateOptions...) options.scope = hOptions.Scope() + if err = s.SetState(ctx, stateType.Parameters, aState, options); err != nil { return err } + if initializer, ok := dest.(state.Initializer); ok { err = initializer.Init(ctx) } return err } -func (s *Session) handleComponentpOutputType(ctx context.Context, dest interface{}, stateOptions []locator.Option) error { +func (s *Session) handleInputState(ctx context.Context, hOptions *hstate.Options, embedFs *embed.FS) error { + // Handle WithInput: preload cache from provided input data + input := hOptions.Input() + if input == nil { + return nil + } + var parameters state.Parameters + var inputType *state.Type + // If input type matches component input type, reuse component parameters + if s.component != nil && s.component.Input.Type.Type() != nil && s.component.Input.Type.Type().Type() != nil { + compInType := s.component.Input.Type.Type().Type() + inType := reflect.TypeOf(input) + if inType != nil && compInType != nil && types.EnsureStruct(inType) == types.EnsureStruct(compInType) { + parameters = s.component.Input.Type.Parameters + inputType = &s.component.Input.Type + } + } + // Otherwise, derive parameters from input type + if len(parameters) == 0 { + inType := reflect.TypeOf(input) + aType, e := state.NewType( + state.WithFS(embedFs), + state.WithSchema(state.NewSchema(inType)), + state.WithResource(s.resource), + ) + if e != nil { + return e + } + if e = aType.Init(); e != nil { + return e + } + inputType = 
aType + for _, p := range aType.Parameters { + p.Init(ctx, s.view.Resource()) + } + parameters = aType.Parameters + } + + var skipOption []LoadStateOption + skipOption = append(skipOption, WithHasMarker()) + if s.view.Mode != view.ModeQuery { + //this is for patch component only (in the future we may pass it to caller when call Bind + skipOption = append(skipOption, WithLoadStateSkipKind(state.KindView, state.KindParam)) + } + if e := s.LoadState(parameters, input, skipOption...); e != nil { + return e + } + if s.view.Mode == view.ModeQuery { + inputState := inputType.Type().WithValue(input) + options := s.Options.Indirect(true) + if err := s.SetState(ctx, parameters, inputState, options); err != nil { + return err + } + _ = s.SetViewState(ctx, s.view) + } + return nil +} + +func (s *Session) handleComponentOutputType(ctx context.Context, dest interface{}, stateOptions []locator.Option) error { sessionOpt := s.Options s.Options = *s.Indirect(true, stateOptions...) destValue, err := s.operate(ctx, s, s.component) - s.Options = sessionOpt - - if destValue != nil { - reflect.ValueOf(dest).Elem().Set(reflect.ValueOf(destValue).Elem()) + destPtr := reflect.ValueOf(dest) + if err != nil && destValue == nil { + if errorSetter, ok := dest.(response.StatusSetter); ok { + errorSetter.SetError(err) + return nil + } + return err } + s.Options = sessionOpt + reflectDestValue := reflect.ValueOf(destValue) - if err != nil { - return err + if reflectDestValue.Kind() == reflect.Ptr { + destPtr.Elem().Set(reflectDestValue.Elem()) + } else { + destPtr.Elem().Set(reflectDestValue) } return nil } diff --git a/shared/args.go b/shared/args.go index 0fcf59a08..198ddf3a6 100644 --- a/shared/args.go +++ b/shared/args.go @@ -1,10 +1,88 @@ package shared -import "strings" - func EnsureArgs(query string, args *[]interface{}) { - parameterCount := strings.Count(query, "?") + parameterCount := countPlaceholders(query) for i := len(*args); i < parameterCount; i++ { //ensure parameters *args = 
append(*args, "") } } + +func countPlaceholders(query string) int { + count := 0 + inSingle := false + inDouble := false + inBacktick := false + inLineComment := false + inBlockComment := false + + for i := 0; i < len(query); i++ { + ch := query[i] + + if inLineComment { + if ch == '\n' || ch == '\r' { + inLineComment = false + } + continue + } + if inBlockComment { + if ch == '*' && i+1 < len(query) && query[i+1] == '/' { + inBlockComment = false + i++ + } + continue + } + if inSingle { + if ch == '\\' { + if i+1 < len(query) { + i++ + } + continue + } + if ch == '\'' { + inSingle = false + } + continue + } + if inDouble { + if ch == '\\' { + if i+1 < len(query) { + i++ + } + continue + } + if ch == '"' { + inDouble = false + } + continue + } + if inBacktick { + if ch == '`' { + inBacktick = false + } + continue + } + + if ch == '-' && i+1 < len(query) && query[i+1] == '-' { + inLineComment = true + i++ + continue + } + if ch == '/' && i+1 < len(query) && query[i+1] == '*' { + inBlockComment = true + i++ + continue + } + + switch ch { + case '\'': + inSingle = true + case '"': + inDouble = true + case '`': + inBacktick = true + case '?': + count++ + } + } + return count +} diff --git a/shared/args_test.go b/shared/args_test.go new file mode 100644 index 000000000..8b0f4dd64 --- /dev/null +++ b/shared/args_test.go @@ -0,0 +1,33 @@ +package shared + +import "testing" + +func TestCountPlaceholders(t *testing.T) { + testCases := []struct { + name string + query string + expect int + }{ + { + name: "simple placeholders", + query: "SELECT * FROM t WHERE a = ? AND b = ?", + expect: 2, + }, + { + name: "ignore single quoted regex", + query: "SELECT REGEXP_REPLACE(col, r'^(?:https?://)?(?:www\\.)?', '') FROM t WHERE a = ?", + expect: 1, + }, + { + name: "ignore comments and quoted text", + query: "SELECT '?' -- ?\nFROM t /* ? 
*/ WHERE x = ?", + expect: 1, + }, + } + + for _, testCase := range testCases { + if actual := countPlaceholders(testCase.query); actual != testCase.expect { + t.Fatalf("%s: expected %d placeholders, got %d", testCase.name, testCase.expect, actual) + } + } +} diff --git a/shared/combine.go b/shared/combine.go index b63329fa2..67cbbecca 100644 --- a/shared/combine.go +++ b/shared/combine.go @@ -7,7 +7,7 @@ func CombineErrors(header string, errors []error) error { return nil } - outputErr := fmt.Errorf(header) + outputErr := fmt.Errorf("%s", header) for _, err := range errors { outputErr = fmt.Errorf("%w; %v", outputErr, err.Error()) } diff --git a/shared/http.go b/shared/http.go index 5b84f39f5..6be311b52 100644 --- a/shared/http.go +++ b/shared/http.go @@ -3,22 +3,56 @@ package shared import ( "bytes" "io" + "mime" "net/http" + "strings" ) // CloneHTTPRequest clones http request func CloneHTTPRequest(request *http.Request) (*http.Request, error) { - var data []byte - var err error + // Shallow clone; special-case multipart to avoid buffering entire body ret := *request ret.URL = request.URL - if request.Body != nil { - if data, err = readRequestBody(request); err != nil { - return nil, err + + if request.Body == nil { + return &ret, nil + } + + // Detect multipart/*; avoid reading/consuming body + if IsMultipartRequest(request) { + // If multipart form has already been parsed, we don't need to + // share or re-read the body. Instead, reuse the parsed form and + // multipart data on the clone so that downstream logic can access + // form values without touching the body again. + if request.MultipartForm != nil { + // Body is no longer needed for form access. + ret.Body = http.NoBody + // Reuse parsed forms and multipart metadata. 
+ ret.MultipartForm = request.MultipartForm + if request.Form != nil { + ret.Form = request.Form + } + if request.PostForm != nil { + ret.PostForm = request.PostForm + } + + return &ret, nil } - ret.Body = io.NopCloser(bytes.NewReader(data)) + + // Backwards compatibility: if the multipart form hasn't been + // parsed yet, fall back to sharing the body. Callers must + // still ensure only one reader consumes it. + ret.Body = request.Body + return &ret, nil } - return &ret, err + + // Non-multipart: safe full read, reset both original and clone bodies + data, err := readRequestBody(request) + if err != nil { + return nil, err + } + ret.Body = io.NopCloser(bytes.NewReader(data)) + return &ret, nil } func readRequestBody(request *http.Request) ([]byte, error) { @@ -30,3 +64,28 @@ func readRequestBody(request *http.Request) ([]byte, error) { request.Body = io.NopCloser(bytes.NewReader(data)) return data, err } + +// IsMultipartRequest returns true if request Content-Type is multipart/* +func IsMultipartRequest(r *http.Request) bool { + if r == nil || r.Header == nil { + return false + } + return IsMultipartContentType(r.Header.Get("Content-Type")) +} + +// IsMultipartContentType returns true when the Content-Type header indicates any multipart/* +func IsMultipartContentType(ct string) bool { + if ct == "" { + return false + } + mediaType, _, err := mime.ParseMediaType(ct) + if err != nil { + return strings.Contains(strings.ToLower(ct), "multipart/") + } + return strings.HasPrefix(strings.ToLower(mediaType), "multipart/") +} + +// IsFormData returns true when mediaType equals multipart/form-data +func IsFormData(mediaType string) bool { + return strings.EqualFold(mediaType, "multipart/form-data") +} diff --git a/shared/logging/logger.go b/shared/logging/logger.go new file mode 100644 index 000000000..de26bc87d --- /dev/null +++ b/shared/logging/logger.go @@ -0,0 +1,441 @@ +package logging + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "os" + 
regexp "regexp" + "runtime" + strings "strings" + + "github.com/aws/aws-lambda-go/events" + "github.com/viant/xdatly/handler/exec" + "github.com/viant/xdatly/handler/logger" +) + +const ( + ReqId = "RequestId" + OpenTelemetryTraceId = "OpenTelemetryTraceId" + DEBUG = "DEBUG" + INFO = "INFO" + WARN = "WARN" + ERROR = "ERROR" + UNKNOWN = "UNKNOWN" // Indicate other environment + DefaultTraceIdKey = "reqTraceId" +) + +type slogger struct { + logger *slog.Logger + level slog.Level + traceIdKey string +} + +type Option func(l *slogger) + +func WithTraceIdKey(key string) Option { + return func(l *slogger) { + l.traceIdKey = key + } +} + +// Init creates an ISLogger instance, a structured logger using the JSON Handler. +// Creating this logger sets this as the default logger, so any logging after this +// which goes through the standard logging package will also produce JSON structured +// logs. +func New(level string, dest io.Writer, opts ...Option) logger.Logger { + if dest == nil { + dest = os.Stdout + } + + logLevel := slog.LevelInfo + switch strings.ToUpper(level) { + case DEBUG: + logLevel = slog.LevelDebug + case WARN: + logLevel = slog.LevelWarn + case ERROR: + logLevel = slog.LevelError + } + + handler := slog.NewJSONHandler(dest, &slog.HandlerOptions{ + AddSource: false, + Level: logLevel, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + // Rename the time key to "timestamp" + if a.Key == slog.TimeKey { + a.Key = "timestamp" + } + return a + }, + }) + sl := slog.New(handler) + slog.SetDefault(sl) + l := &slogger{sl, logLevel, DefaultTraceIdKey} + for _, opt := range opts { + opt(l) + } + + return l +} + +func (s *slogger) IsDebugEnabled() bool { + return s.level.Level() <= slog.LevelDebug +} + +func (s *slogger) IsInfoEnabled() bool { + return s.level.Level() <= slog.LevelInfo +} + +func (s *slogger) IsWarnEnabled() bool { + return s.level.Level() <= slog.LevelWarn +} + +func (s *slogger) IsErrorEnabled() bool { + return s.level.Level() <= 
slog.LevelError +} + +// getCallerInfo uses runtime to get the caller's program counter +// and extract info from the stack frame to get the function name, etc. +func (s *slogger) getCallerInfo() []any { + callers := make([]uintptr, 1) + count := runtime.Callers(3, callers[:]) // skip to actual caller + if count == 0 { + slog.Warn("getCallerInfo: no frames, exiting") + return nil + } + + frames := runtime.CallersFrames(callers) + var frame runtime.Frame + var more bool + for { + frame, more = frames.Next() + if !more { + break + } + } + + attr := []any{ + "function", frame.Function, "file", frame.File, "line", frame.Line, + } + + return attr +} + +// getContextValues retrieves "known" logging values from the Context. +// These values can be added to the Context using the provided utility functions. +func (s *slogger) getContextValues(ctx context.Context) []any { + var values []any + if ctx == nil { + slog.Warn("getContextValues: ctx is nil") + return nil + } + + openTelemetryTraceId := ctx.Value(OpenTelemetryTraceId) + if openTelemetryTraceId != nil { + values = append(values, "OpenTelemetryTraceId", openTelemetryTraceId) + } + + execContext := exec.GetContext(ctx) + if execContext != nil { + traceId := "unknown" + + // ideally TraceID and Trace.TraceID should be the same + // but xdatly/handler/exec.(*Context).setHeader sets TraceID first + // with the value of XDATLY_TRACING_HEADER env var value header (adp-request-id for datly platform) + if execContext.TraceID != "" { + traceId = execContext.TraceID + } else if execContext.Trace != nil { + traceId = execContext.Trace.TraceID + } + values = append(values, s.traceIdKey, traceId) + } + + return values +} + +// Info wraps a call to slog.Info, inserting details for the calling function. +func (s *slogger) Info(msg string, args ...any) { + if !s.IsInfoEnabled() { + return + } + caller := s.getCallerInfo() + caller = append(caller, args...) + s.logger.Info(msg, caller...) 
+} + +// Debug wraps a call to slog.Debug, inserting details for the calling function. +func (s *slogger) Debug(msg string, args ...any) { + if !s.IsDebugEnabled() { + return + } + caller := s.getCallerInfo() + caller = append(caller, args...) + s.logger.Debug(msg, caller...) +} + +// Warn wraps a call to slog.Warn, inserting details for the calling function. +func (s *slogger) Warn(msg string, args ...any) { + if !s.IsWarnEnabled() { + return + } + caller := s.getCallerInfo() + caller = append(caller, args...) + s.logger.Warn(msg, caller...) +} + +// Error wraps a call to slog.Error, inserting details for the calling function. +func (s *slogger) Error(msg string, args ...any) { + if !s.IsErrorEnabled() { + return + } + caller := s.getCallerInfo() + caller = append(caller, args...) + s.logger.Error(msg, caller...) +} + +// Infoc wraps a call to slog.Info, inserting details for the calling function, +// and retrieving known values from the context object. +func (s *slogger) Infoc(ctx context.Context, msg string, args ...any) { + if !s.IsInfoEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, args...) + s.logger.Info(msg, caller...) +} + +func (s *slogger) Infos(ctx context.Context, msg string, attrs ...slog.Attr) { + if !s.IsInfoEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, redactAttrs(attrs...)...) + + s.logger.Info(msg, caller...) +} + +// Debugc wraps a call to slog.Debug, inserting details for the calling function, +// and retrieving known values from the context object. +func (s *slogger) Debugc(ctx context.Context, msg string, args ...any) { + if !s.IsDebugEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, args...) + s.logger.Debug(msg, caller...) 
+} + +func (s *slogger) Debugs(ctx context.Context, msg string, attrs ...slog.Attr) { + if !s.IsDebugEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, redactAttrs(attrs...)...) + + s.logger.Debug(msg, caller...) +} + +// DebugJSONc wraps a call to slog.Debug, inserting details for the calling function, +// and retrieving known values from the context object. +func (s *slogger) DebugJSONc(ctx context.Context, msg string, obj any) { + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + + // Initialize request and jsonData variables + var request events.APIGatewayProxyRequest + var jsonData []byte + // Marshal the object to JSON string + jsonString, _ := json.Marshal(obj) + // Unmarshal JSON string to APIGatewayProxyRequest + err := json.Unmarshal(jsonString, &request) + if err != nil { + return + } + + // Check if the request has an HTTP method + if len(request.HTTPMethod) > 0 { + if request.MultiValueHeaders == nil { + request.MultiValueHeaders = make(map[string][]string) + } + // Remove Authorization header + request.MultiValueHeaders["Authorization"] = nil + request.Headers["Authorization"] = "" + // Marshal the modified request to JSON + jsonData, _ = json.Marshal(request) + } else { + jsonData = jsonString + } + msg = fmt.Sprintf("%s %s", msg, string(jsonData)) + s.Debugc(ctx, msg, caller...) +} + +// Warnc wraps a call to slog.Warn, inserting details for the calling function, +// and retrieving known values from the context object. +func (s *slogger) Warnc(ctx context.Context, msg string, args ...any) { + if !s.IsWarnEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, args...) + s.logger.Warn(msg, caller...) 
+} + +func (s *slogger) Warns(ctx context.Context, msg string, attrs ...slog.Attr) { + if !s.IsWarnEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, redactAttrs(attrs...)...) + + s.logger.Warn(msg, caller...) +} + +// Errorc wraps a call to slog.Error, inserting details for the calling function, +// and retrieving known values from the context object. +func (s *slogger) Errorc(ctx context.Context, msg string, args ...any) { + if !s.IsErrorEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, args...) + s.logger.Error(msg, caller...) +} + +func (s *slogger) Errors(ctx context.Context, msg string, attrs ...slog.Attr) { + if !s.IsErrorEnabled() { + return + } + caller := s.getCallerInfo() + values := s.getContextValues(ctx) + caller = append(caller, values...) + caller = append(caller, redactAttrs(attrs...)...) + + s.logger.Error(msg, caller...) +} + +// Helper to get platform from environment suffix +func getPlatformFromEnv(environment string) string { + switch { + case strings.Contains(environment, "dev"): + return "development" + case strings.Contains(environment, "stage"): + return "stage" + case strings.Contains(environment, "prod"): + return "production" + default: + return UNKNOWN + } +} + +// redactAttrs applies redaction rules to slog.Attr list. 
+// Skip redactValue for primitive types to avoid unnecessary processing +// This avoids redundant type switch/marshalling cost in high-volume logging +func redactAttrs(attrs ...slog.Attr) []any { + var result []any + for _, attr := range attrs { + if isSensitiveKey(attr.Key) { + result = append(result, slog.String(attr.Key, "[REDACTED]")) + continue + } + val := attr.Value.Any() + switch val.(type) { + case int, int64, float64, bool, nil: + result = append(result, attr) + default: + redactedValue := slog.AnyValue(redactValue(val)) + result = append(result, slog.Attr{Key: attr.Key, Value: redactedValue}) + } + } + return result +} + +// redactValue recursively redacts sensitive info in maps, slices, or structs. +func redactValue(value any) any { + switch v := value.(type) { + case string: + // Redact sensitive information in strings + return redactSensitiveInfo(v) + case int, int64, float64, bool, nil: + // Return primitive values directly (skip JSON marshalling) + return v + case map[string]any: + // Redact value if key is sensitive (e.g., Authorization → [REDACTED]) + // Ensures map fields are redacted even if value doesn’t match regex + for key, val := range v { + if isSensitiveKey(key) { + v[key] = "[REDACTED]" + } else { + v[key] = redactValue(val) + } + } + return v + case []any: + // Recursively redact sensitive information in slices + for i, val := range v { + v[i] = redactValue(val) + } + return v + default: + // Only marshal/unmarshal if absolutely needed (structs, unknown). + jsonData, err := json.Marshal(v) // Converts struct to map to enable nested field redaction. + if err != nil { + return v // If marshal fails, skip redaction + } + var unmarshaled any + if err := json.Unmarshal(jsonData, &unmarshaled); err != nil { + return v // If unmarshal fails, skip redaction + } + return redactValue(unmarshaled) + } +} + +// redactSensitiveInfo redacts known patterns in a string (e.g., tokens in URLs). 
// Compiled once at package scope: compiling these regexps on every call was
// needless work on a hot logging path.
var (
	// sensitiveKeyValuePatterns match "key=value" and "key: value" style
	// secrets; $1 preserves the key while the value is replaced.
	sensitiveKeyValuePatterns = []*regexp.Regexp{
		// Redact key=value style
		regexp.MustCompile(`(?i)(X-Amz-Security-Token|X-Amz-Signature|X-Amz-Credential|Authorization|password|token|apiKey)=([^&\s]+)`),
		// Redact key: value or key value
		regexp.MustCompile(`(?i)(Authorization|password|token|apiKey)[\s:=]+([^&\s]+)`),
	}
	// urlCredentialsPattern matches user:password@ credentials embedded in a URL.
	urlCredentialsPattern = regexp.MustCompile(`(?i)(https?://)[^/]+:[^@]+@`)
)

// redactSensitiveInfo redacts known patterns in a string (e.g., tokens in URLs).
func redactSensitiveInfo(value string) string {
	redacted := value
	for _, pattern := range sensitiveKeyValuePatterns {
		redacted = pattern.ReplaceAllString(redacted, "$1=[REDACTED]")
	}
	// The credentials-in-URL pattern needs its own replacement: the previous
	// shared "$1=[REDACTED]" replacement expanded $1 to an empty string
	// (the pattern had no capture group), turning "https://user:pass@host"
	// into "=[REDACTED]host" and losing the scheme.
	redacted = urlCredentialsPattern.ReplaceAllString(redacted, "$1[REDACTED]@")
	return redacted
}

// isSensitiveKey returns true if the key is known to contain sensitive data.
// Matching is exact (case-insensitive), not substring-based.
func isSensitiveKey(key string) bool {
	sensitiveKeys := []string{
		"authorization", "token", "apikey", "password",
		"credential", "secret", "access_key", "secret_key",
	}
	key = strings.ToLower(key)
	for _, sk := range sensitiveKeys {
		if key == sk {
			return true
		}
	}
	return false
}

/* NOTE(review): remainder of the original patch chunk on this line, preserved verbatim:
diff --git a/utils/errors/db.go b/utils/errors/db.go
new file mode 100644
index 000000000..19067f0ce
--- /dev/null
+++ b/utils/errors/db.go
@@ -0,0 +1,50 @@
+package errors
+
+import (
+	"errors"
+	"strings"
+)
+
+// IsDatabaseError determines whether the supplied error was caused by the database or driver layer.
+// We inspect the full error chain because many call-sites wrap driver errors with additional context.
*/
func IsDatabaseError(err error) bool {
	if err == nil {
		return false
	}
	return hasDatabaseSignature(err)
}

// hasDatabaseSignature walks the wrap chain via errors.Unwrap, returning true
// as soon as any error's message matches a known database failure signature.
func hasDatabaseSignature(err error) bool {
	for err != nil {
		if matchesDatabasePattern(err.Error()) {
			return true
		}
		err = errors.Unwrap(err)
	}
	return false
}

// databasePatterns holds lowercase signatures of database/driver failures.
// Built once at package scope instead of on every call. Both the historical
// misspelling "occured" and the correct "occurred" are kept on purpose:
// existing messages use either form.
var databasePatterns = []string{
	"database error occured while fetching data",
	"database error occurred while fetching data",
	"error occured while connecting to database",
	"error occurred while connecting to database",
	"failed to get db",
	"failed to create stmt source",
	"too many connections",
	"connection refused",
	"driver: bad connection",
	"sql: transaction has already been committed or rolled back",
}

// matchesDatabasePattern reports whether message contains any known database
// failure signature; comparison is case-insensitive, empty messages never match.
func matchesDatabasePattern(message string) bool {
	if message == "" {
		return false
	}
	lower := strings.ToLower(message)
	for _, pattern := range databasePatterns {
		if strings.Contains(lower, pattern) {
			return true
		}
	}
	return false
}

/* NOTE(review): remainder of the original patch chunk on this line, preserved verbatim:
diff --git a/utils/httputils/violation.go b/utils/httputils/violation.go
index 7ff1c0a35..912c7f19c 100644
--- a/utils/httputils/violation.go
+++ b/utils/httputils/violation.go
@@ -50,7 +50,7 @@ func (v Violations) MergeErrors(errors []*response.Error) validator.Violations {
 		aViolation := &validator.Violation{
 			Location: anError.View + "/" + anError.Parameter,
 			Value:    anError.Object,
-			Check:    fmt.Sprint("%T", anError.Error()),
+			Check:    fmt.Sprintf("%T", anError.Error()),
 			Message:  anError.Message,
 		}
 		ret = append(ret, aViolation)
diff --git a/utils/types/types.go b/utils/types/types.go
index dc29f1235..8e7ccd784 100644
--- a/utils/types/types.go
+++ b/utils/types/types.go
@@ -1,6 +1,7 @@
 package types
 import (
+	"fmt"
 	"github.com/viant/sqlx/io"
 	"github.com/viant/xreflect"
 	"reflect"
@@ -11,6 +12,9 @@ func LookupType(lookup xreflect.LookupType, typeName string, opts ...xreflect.Op
 	if ok {
 		return rType, nil
 	}
+	if lookup == nil {
+		return nil, fmt.Errorf("type %q was not found and no lookup resolver is configured", typeName)
+	}
*/
return lookup(typeName, opts...) } diff --git a/view/column/discover.go b/view/column/discover.go index f78604699..48c7b221a 100644 --- a/view/column/discover.go +++ b/view/column/discover.go @@ -248,7 +248,9 @@ func parseQuery(SQL string) (string, string, sqlparser.Columns) { if sqlQuery.From.X != nil { table = sqlparser.Stringify(sqlQuery.From.X) } - if sqlQuery.List.IsStarExpr() && !strings.Contains(table, "SELECT") { + // For CTE-backed queries (WITH ...), SELECT * FROM cte_alias must still be + // resolved via SQL execution; the alias is not a physical table. + if sqlQuery.List.IsStarExpr() && !strings.Contains(table, "SELECT") && len(sqlQuery.WithSelects) == 0 { return table, "", nil //use table metadata } sqlQuery.Limit = nil diff --git a/view/column/discover_test.go b/view/column/discover_test.go new file mode 100644 index 000000000..ef2cf966e --- /dev/null +++ b/view/column/discover_test.go @@ -0,0 +1,19 @@ +package column + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseQuery_WithCTEStar_DoesNotShortCircuitToTableMetadata(t *testing.T) { + sql := `WITH cte AS (SELECT 1 AS a) SELECT v.* FROM cte v` + + table, discoveredSQL, cols := parseQuery(sql) + require.Equal(t, "cte", strings.TrimSpace(table)) + require.NotEmpty(t, cols) + require.NotEmpty(t, strings.TrimSpace(discoveredSQL), "CTE star query must keep SQL for runtime column inference") + require.Contains(t, strings.ToUpper(discoveredSQL), "WITH CTE AS") + require.Contains(t, discoveredSQL, "LIMIT 1") +} diff --git a/view/config.go b/view/config.go index 485fe5f45..5dd70f6de 100644 --- a/view/config.go +++ b/view/config.go @@ -3,12 +3,13 @@ package view import ( "context" "fmt" + "reflect" + "strings" + "github.com/viant/datly/shared" "github.com/viant/datly/view/state" "github.com/viant/xdatly/codec" "github.com/viant/xreflect" - "reflect" - "strings" ) const ( @@ -90,6 +91,11 @@ func (c *Config) GetContentFormatParameter() *state.Parameter { return 
QueryStateParameters.ContentFormatParameter } +func (c *Constraints) HasOrderByColumn(name string) bool { + _, ok := c.OrderByColumn[name] + return ok +} + func (c *Config) Init(ctx context.Context, resource *Resource, parent *View) error { if err := c.ensureConstraints(resource); err != nil { return err diff --git a/view/extension/handler/loader.go b/view/extension/handler/loader.go index 7dd4fc96f..86a5966c2 100644 --- a/view/extension/handler/loader.go +++ b/view/extension/handler/loader.go @@ -6,6 +6,7 @@ import ( "compress/gzip" "context" "encoding/json" + "errors" "fmt" "github.com/viant/afs" "github.com/viant/datly/utils/types" @@ -42,56 +43,92 @@ func (l *LoadData) Exec(ctx context.Context, session handler.Session) (interface if !ok || err != nil { return nil, fmt.Errorf("invalid Loader URL: %w", err) } + var URL string - switch URLValue.(type) { + switch v := URLValue.(type) { case string: - URL = URLValue.(string) + URL = v case *string: - URL = *URLValue.(*string) + URL = *v default: - return nil, fmt.Errorf("invalid Loader URL: expected %T, but had %T", URL, URLValue) + return nil, fmt.Errorf("invalid Loader URL: expected %T, but had %T", "", URLValue) } + // Prefer .gz if the plain URL doesn't exist. if ok, _ := l.fs.Exists(ctx, URL); !ok { if ok, _ := l.fs.Exists(ctx, URL+".gz"); ok { URL += ".gz" } } - isCompressed := strings.HasSuffix(URL, ".gz") + // Download compressed or plain bytes (API returns []byte). data, err := l.fs.DownloadWithURL(ctx, URL) if err != nil { return nil, fmt.Errorf("failed to load URL: %w", err) } - if isCompressed { - reader, err := gzip.NewReader(bytes.NewReader(data)) + + // Build a streaming reader chain; avoid io.ReadAll on gzip. 
+ var r io.Reader = bytes.NewReader(data) + if strings.HasSuffix(URL, ".gz") { + gzr, err := gzip.NewReader(r) if err != nil { return nil, fmt.Errorf("failed to decompress URL: failed to create reader: %w (used URL: %s)", err, URL) } - defer reader.Close() - if data, err = io.ReadAll(reader); err != nil { - return nil, fmt.Errorf("failed to decompress URL:%w (used URL: %s)", err, URL) - } + defer gzr.Close() + r = gzr } + + br := bufio.NewReaderSize(r, 1<<20) // read-ahead; does NOT cap JSON size + dec := json.NewDecoder(br) + dec.UseNumber() + + // Output slice + appender (kept from your original design) itemType := l.Options.OutputType.Elem() xSlice := xunsafe.NewSlice(l.Options.OutputType) - scanner := bufio.NewScanner(bytes.NewReader(data)) response := reflect.New(l.Options.OutputType).Interface() appender := xSlice.Appender(xunsafe.AsPointer(response)) - scanner.Buffer(make([]byte, 1024*1024), 5*1024*1024) - for scanner.Scan() { - line := scanner.Bytes() - if len(line) == 0 { - continue + + // Reject top-level arrays to keep the code simple (no streaming array parsing). + first, err := peekFirstNonSpace(br) + if err != nil { + if errors.Is(err, io.EOF) { + return response, nil // empty file -> empty slice } - item := types.NewValue(itemType) - err := json.Unmarshal(scanner.Bytes(), item) - if err != nil { - return nil, fmt.Errorf("invalid item: %w, %s", err, line) + return nil, fmt.Errorf("read error: %w", err) + } + if first == '[' { + return nil, fmt.Errorf("top-level JSON arrays are not supported; provide NDJSON (one object per line) or a single JSON object") + } + // Put the byte back so the decoder sees it. + _ = br.UnreadByte() + + // Decode one value per call: supports single object or NDJSON. 
+ for { + item := types.NewValue(itemType) // pointer to zero value of element type + if err := dec.Decode(item); err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, fmt.Errorf("invalid item: %w", err) } appender.Append(item) } - return response, scanner.Err() + + return response, nil +} + +// Reads and returns the first non-space byte without consuming input for the decoder. +func peekFirstNonSpace(br *bufio.Reader) (byte, error) { + for { + b, err := br.ReadByte() + if err != nil { + return 0, err + } + if b == ' ' || b == '\n' || b == '\r' || b == '\t' { + continue + } + return b, nil + } } func (*LoadDataProvider) New(ctx context.Context, opts ...handler.Option) (handler.Handler, error) { diff --git a/view/extension/init.go b/view/extension/init.go index db515af42..ea5ff3730 100644 --- a/view/extension/init.go +++ b/view/extension/init.go @@ -3,6 +3,9 @@ package extension import ( "encoding/json" "fmt" + "mime/multipart" + "net/http" + dcodec "github.com/viant/datly/view/extension/codec" "github.com/viant/datly/view/extension/handler" "github.com/viant/datly/view/extension/marshaller" @@ -17,14 +20,14 @@ import ( "github.com/viant/xdatly/handler/response/tabular/tjson" "github.com/viant/xdatly/handler/response/tabular/xml" "github.com/viant/xdatly/handler/validator" - "net/http" + + "reflect" + "time" "github.com/viant/xdatly/predicate" "github.com/viant/xdatly/types/core" _ "github.com/viant/xdatly/types/custom" "github.com/viant/xreflect" - "reflect" - "time" ) const ( @@ -50,7 +53,8 @@ func InitRegistry() { xreflect.NewType("validator.Violation", xreflect.WithReflectType(reflect.TypeOf(validator.Violation{}))), xreflect.NewType("RawMessage", xreflect.WithReflectType(reflect.TypeOf(json.RawMessage{}))), xreflect.NewType("json.RawMessage", xreflect.WithReflectType(reflect.TypeOf(json.RawMessage{}))), - xreflect.NewType("json.RawMessage", xreflect.WithReflectType(reflect.TypeOf(json.RawMessage{}))), + xreflect.NewType("FileHeader", 
xreflect.WithReflectType(reflect.TypeOf(multipart.FileHeader{}))), + xreflect.NewType("multipart.FileHeader", xreflect.WithReflectType(reflect.TypeOf(multipart.FileHeader{}))), xreflect.NewType("types.BitBool", xreflect.WithReflectType(reflect.TypeOf(types.BitBool(true)))), xreflect.NewType("time.Time", xreflect.WithReflectType(xreflect.TimeType)), xreflect.NewType("response.Status", xreflect.WithReflectType(reflect.TypeOf(response.Status{}))), @@ -119,6 +123,7 @@ func InitRegistry() { PredicateGreaterOrEqual: NewGreaterOrEqualPredicate(), PredicateGreaterThan: NewGreaterThanPredicate(), PredicateLike: NewLikePredicate(), + PredicateExpr: NewExprPredicate(), PredicateNotLike: NewNotLikePredicate(), PredicateHandler: NewPredicateHandler(), PredicateContains: NewContainsPredicate(), diff --git a/view/extension/predicates.go b/view/extension/predicates.go index 04b0ba098..ae0276e95 100644 --- a/view/extension/predicates.go +++ b/view/extension/predicates.go @@ -2,12 +2,13 @@ package extension import ( "fmt" + "sync" + "github.com/viant/datly/utils/types" codec2 "github.com/viant/datly/view/extension/codec" "github.com/viant/xdatly/codec" "github.com/viant/xdatly/predicate" "github.com/viant/xreflect" - "sync" ) const ( @@ -32,6 +33,7 @@ const ( PredicateExists = "exists" PredicateNotExists = "not_exists" + PredicateExpr = "expr" PredicateCriteriaExists = "exists_criteria" PredicateCriteriaNotExists = "not_exists_criteria" PredicateCriteriaIn = "in_criteria" @@ -225,6 +227,10 @@ func NewEqualPredicate() *Predicate { return binaryPredicate(PredicateEqual, "=") } +func NewColumnExpressionPredicate() *Predicate { + return binaryPredicate(PredicateEqual, "=") +} + func NewLessOrEqualPredicate() *Predicate { return binaryPredicate(PredicateLessOrEqual, "<=") } @@ -333,6 +339,10 @@ func NewLikePredicate() *Predicate { return newLikePredicate(PredicateLike, true) } +func NewExprPredicate() *Predicate { + return newExprPredicate(PredicateExpr) +} + func NewNotLikePredicate() 
*Predicate { return newLikePredicate(PredicateNotLike, false) } @@ -362,6 +372,23 @@ func newLikePredicate(name string, inclusive bool) *Predicate { } } +func newExprPredicate(expr string) *Predicate { + args := []*predicate.NamedArgument{ + { + Name: "Expression", + Position: 0, + }, + } + criteria := fmt.Sprintf(`$criteria.Expression($Expression, $FilterValue)`) + return &Predicate{ + Template: &predicate.Template{ + Name: expr, + Source: " " + criteria, + Args: args, + }, + } +} + func NewContainsPredicate() *Predicate { return newContainsPredicate(PredicateContains, true) } diff --git a/view/predicate.go b/view/predicate.go index 69eb438f4..415550c28 100644 --- a/view/predicate.go +++ b/view/predicate.go @@ -3,6 +3,10 @@ package view import ( "context" "fmt" + "reflect" + "strings" + "sync" + expand "github.com/viant/datly/service/executor/expand" "github.com/viant/datly/utils/types" "github.com/viant/datly/view/extension" @@ -12,9 +16,6 @@ import ( "github.com/viant/xdatly/predicate" "github.com/viant/xreflect" "github.com/viant/xunsafe" - "reflect" - "strings" - "sync" ) type ( @@ -34,6 +35,7 @@ type ( state *expand.NamedVariable hasStateName *expand.NamedVariable handler codec.PredicateHandler + stateType *structology.StateType } PredicateEvaluator struct { @@ -41,6 +43,7 @@ type ( evaluator *expand.Evaluator valueState *expand.NamedVariable hasValueState *expand.NamedVariable + stateType *structology.StateType } ) @@ -51,20 +54,34 @@ func (e *PredicateEvaluator) Compute(ctx context.Context, value interface{}) (*c } val := ctx.Value(expand.PredicateState) - aState := val.(*structology.State) - offset := len(cuxtomCtx.DataUnit.ParamsGroup) - evaluate, err := e.Evaluate(cuxtomCtx, aState, value) + var aState *structology.State + if s, ok := val.(*structology.State); ok { + aState = s + } + if aState == nil && e.stateType != nil { + // Initialize state if absent; do not override if provided. 
+ aState = e.stateType.NewState() + } + // evaluate predicate with an isolated DataUnit to avoid + // mutating parent DataUnit and relying on Shrink/restore across nesting. + var metaSource expand.Dber + if cuxtomCtx.DataUnit != nil { + metaSource = cuxtomCtx.DataUnit.MetaSource + } + isolatedDU := expand.NewDataUnit(metaSource) + tmpCtx := *cuxtomCtx + tmpCtx.DataUnit = isolatedDU + + evaluate, err := e.Evaluate(&tmpCtx, aState, value) if err != nil { return nil, err } - placeholderLen := len(evaluate.DataUnit.ParamsGroup) - offset - var values = make([]interface{}, placeholderLen) - if placeholderLen > 0 { - copy(values, evaluate.DataUnit.ParamsGroup[offset:]) - } + // Collect placeholders from the isolated DataUnit and return them + // to the caller; do not mutate the parent DataUnit here. + values := make([]interface{}, len(isolatedDU.ParamsGroup)) + copy(values, isolatedDU.ParamsGroup) criteria := &codec.Criteria{Expression: evaluate.Buffer.String(), Placeholders: values} - cuxtomCtx.DataUnit.ParamsGroup = cuxtomCtx.DataUnit.ParamsGroup[:offset] return criteria, nil } @@ -150,6 +167,7 @@ func (p *predicateEvaluatorProvider) new(predicateConfig *extension.PredicateCon evaluator: p.evaluator, valueState: p.state, hasValueState: p.hasStateName, + stateType: p.stateType, }, nil } @@ -207,5 +225,6 @@ func (p *predicateEvaluatorProvider) init(resource *Resource, predicateConfig *e p.signature = argsIndexed p.state = stateVariable p.hasStateName = hasVariable + p.stateType = stateType return nil } diff --git a/view/resource.go b/view/resource.go index 94ca4db9a..b1867d19e 100644 --- a/view/resource.go +++ b/view/resource.go @@ -72,7 +72,8 @@ type ( Substitutes Substitutes Docs *Documentation - FSEmbedder *state.FSEmbedder + + FSEmbedder *state.FSEmbedder modTime time.Time _doc docs.Service @@ -152,6 +153,17 @@ func (r *Resource) ReverseSubstitutes(text string) string { return r.Substitutes.ReverseReplace(text) } +func (r *Resource) EmbedFS() *embed.FS { + if 
r.FSEmbedder == nil { + return nil + } + return r.FSEmbedder.EmbedFS() +} + +func (r *Resource) SetFSEmbedder(embedder *state.FSEmbedder) { + r.FSEmbedder = embedder +} + func (r *Resource) SetFs(fs afs.Service) { r.fs = fs } @@ -566,7 +578,12 @@ func LoadResourceFromURL(ctx context.Context, URL string, fs afs.Service) (*Reso resource := &Resource{} err = toolbox.DefaultConverter.AssignConverted(resource, aMap) if err != nil { - return nil, err + if docs, ok := parseDocumentationOnlyResource(aMap); ok { + resource.Docs = &Documentation{Docs: docs} + err = nil + } else { + return nil, err + } } resource.fs = fs resource.SourceURL = URL @@ -574,6 +591,76 @@ func LoadResourceFromURL(ctx context.Context, URL string, fs afs.Service) (*Reso return resource, err } +func parseDocumentationOnlyResource(source map[string]interface{}) (*state.Docs, bool) { + if len(source) == 0 { + return nil, false + } + + // Route-like YAML should never be treated as dependency docs. + for _, key := range []string{"Routes", "Method", "URI", "Input", "Output", "View", "Handler", "Resource"} { + if _, ok := source[key]; ok { + return nil, false + } + } + + // If clear resource keys exist, keep regular conversion semantics. 
+ for _, key := range []string{"Connectors", "Views", "Types", "Substitutes", "MessageBuses", "CacheProviders", "Loggers", "Predicates", "Docs", "FSEmbedder", "Imports"} { + if _, ok := source[key]; ok { + return nil, false + } + } + + ret := &state.Docs{} + found := false + for _, section := range []struct { + name string + dest *state.Documentation + }{ + {name: "Parameters", dest: &ret.Parameters}, + {name: "Columns", dest: &ret.Columns}, + {name: "Paths", dest: &ret.Paths}, + {name: "Filter", dest: &ret.Filter}, + } { + raw, ok := source[section.name] + if !ok { + continue + } + doc, ok := asDocumentation(raw) + if !ok { + return nil, false + } + *section.dest = doc + found = true + } + if !found { + return nil, false + } + return ret, true +} + +func asDocumentation(raw interface{}) (state.Documentation, bool) { + aMap, ok := raw.(map[string]interface{}) + if !ok { + return nil, false + } + ret := state.Documentation{} + for key, value := range aMap { + switch actual := value.(type) { + case string: + ret[key] = actual + case map[string]interface{}: + nested, ok := asDocumentation(actual) + if !ok { + return nil, false + } + ret[key] = nested + default: + return nil, false + } + } + return ret, true +} + func (r *Resource) FindConnector(view *View) (*Connector, error) { if view.Connector == nil { var connector *Connector diff --git a/view/state.go b/view/state.go index 3484de7f8..a187cd4b5 100644 --- a/view/state.go +++ b/view/state.go @@ -1,12 +1,14 @@ package view import ( + "strings" + "sync" + "github.com/viant/datly/view/state/predicate" "github.com/viant/sqlx/io/read/cache" "github.com/viant/structology" "github.com/viant/tagly/format/text" - "strings" - "sync" + "github.com/viant/xdatly/handler/state" ) // Statelet allows customizing View fetched from Database @@ -14,9 +16,18 @@ type ( //InputType represents view state Statelet struct { - Template *structology.State - QuerySelector + //SELECTORS + DatabaseFormat text.CaseFormat + OutputFormat 
text.CaseFormat + Template *structology.State + state.QuerySelector QuerySettings + filtersMu sync.Mutex + initialized bool + _columnNames map[string]bool + result *cache.ParmetrizedQuery + predicate.Filters + Ignore bool } QuerySettings struct { @@ -24,44 +35,11 @@ type ( SyncFlag bool ContentFormat string } - - QuerySelector struct { - //SELECTORS - DatabaseFormat text.CaseFormat - OutputFormat text.CaseFormat - Columns []string `json:",omitempty"` - Fields []string `json:",omitempty"` - OrderBy string `json:",omitempty"` - Offset int `json:",omitempty"` - Limit int `json:",omitempty"` - - Criteria string `json:",omitempty"` - Placeholders []interface{} `json:",omitempty"` - Page int - Ignore bool - predicate.Filters - - initialized bool - _columnNames map[string]bool - result *cache.ParmetrizedQuery - } ) -func (s *QuerySelector) CurrentLimit() int { - return s.Limit -} - -func (s *QuerySelector) CurrentOffset() int { - return s.Offset -} - -func (s *QuerySelector) CurrentPage() int { - return s.Page -} - // Init initializes Statelet func (s *Statelet) Init(aView *View) { - if aView != nil && s.Template == nil && aView.Template.stateType != nil { + if aView != nil && s.Template == nil && aView.Template != nil && aView.Template.stateType != nil { s.Template = aView.Template.stateType.NewState() } if s.initialized { @@ -71,12 +49,12 @@ func (s *Statelet) Init(aView *View) { } // Has checks if Field is present in Template.Columns -func (s *QuerySelector) Has(field string) bool { +func (s *Statelet) Has(field string) bool { _, ok := s._columnNames[field] return ok } -func (s *QuerySelector) Add(fieldName string, isHolder bool) { +func (s *Statelet) Add(fieldName string, isHolder bool) { toLower := strings.ToLower(fieldName) if _, ok := s._columnNames[toLower]; ok { return @@ -94,18 +72,21 @@ func (s *QuerySelector) Add(fieldName string, isHolder bool) { } } -func (s *QuerySelector) SetCriteria(expanded string, placeholders []interface{}) { - s.Criteria = expanded - 
s.Placeholders = placeholders +// AppendFilters safely appends filters to the selector's Filters to avoid data races. +func (s *Statelet) AppendFilters(filters predicate.Filters) { + if len(filters) == 0 { + return + } + s.filtersMu.Lock() + s.Filters = append(s.Filters, filters...) + s.filtersMu.Unlock() } // NewStatelet creates a selector func NewStatelet() *Statelet { return &Statelet{ - QuerySelector: QuerySelector{ - _columnNames: map[string]bool{}, - initialized: true, - }, + _columnNames: map[string]bool{}, + initialized: true, } } @@ -116,7 +97,7 @@ type State struct { } // QuerySelector returns query selector -func (s *State) QuerySelector(view *View) *QuerySelector { +func (s *State) QuerySelector(view *View) *state.QuerySelector { statelet := s.Lookup(view) if statelet == nil { return nil diff --git a/view/state/hook.go b/view/state/hook.go index 9ae5b49c4..e871f4d6d 100644 --- a/view/state/hook.go +++ b/view/state/hook.go @@ -1,6 +1,11 @@ package state -import "context" +import ( + "context" + + "github.com/viant/xdatly/handler/http" + "github.com/viant/xdatly/handler/state" +) // Initializer is an interface that should be implemented by any type that needs to be initialized type Initializer interface { @@ -11,3 +16,7 @@ type Initializer interface { type Finalizer interface { Finalize(ctx context.Context) error } + +type InjectorFinalizer interface { + Finalize(ctx context.Context, getInjector func(ctx context.Context, path http.Route) (state.Injector, error)) error +} diff --git a/view/state/kind/locator.go b/view/state/kind/locator.go index 57f765f0e..864d85b00 100644 --- a/view/state/kind/locator.go +++ b/view/state/kind/locator.go @@ -2,6 +2,8 @@ package kind import ( "context" + "reflect" + "github.com/viant/datly/view/state" ) @@ -9,7 +11,7 @@ import ( type Locator interface { //Value returns parameter value - Value(ctx context.Context, name string) (interface{}, bool, error) + Value(ctx context.Context, rType reflect.Type, name string) 
(interface{}, bool, error) //Names returns names of supported parameters Names() []string diff --git a/view/state/kind/locator/body.go b/view/state/kind/locator/body.go index fc41a49aa..e17af401b 100644 --- a/view/state/kind/locator/body.go +++ b/view/state/kind/locator/body.go @@ -3,13 +3,16 @@ package locator import ( "context" "fmt" + "mime" + "mime/multipart" + "net/http" + "reflect" + "sync" + "github.com/viant/datly/shared" "github.com/viant/datly/view/state/kind" "github.com/viant/structology" hstate "github.com/viant/xdatly/handler/state" - "net/http" - "reflect" - "sync" ) type Body struct { @@ -21,22 +24,39 @@ type Body struct { request *http.Request err error sync.Once + isMultipart bool } +const maxMultipartMemory = 32 << 20 // 32 MiB + func (r *Body) Names() []string { return nil } -func (r *Body) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (r *Body) Value(ctx context.Context, rType reflect.Type, name string) (interface{}, bool, error) { var err error - r.Once.Do(func() { - var request *http.Request - request, r.err = shared.CloneHTTPRequest(r.request) - r.body, r.err = readRequestBody(request) - if len(r.body) > 0 { - r.err = r.ensureRequest() + r.initOnce() + var requestState *structology.State + + // Multipart handling + if r.isMultipart { + return r.handleMultipartValue(rType, name) + } + + if len(r.body) > 0 { + if r.requestState != nil && r.requestState.Type().Type() == rType { + requestState = r.requestState } - }) + if name == "" { + requestState, r.err = r.ensureRequest(rType) + } else { + requestState, r.err = r.ensureRequest(r.bodyType) + } + if r.err == nil { + r.requestState = requestState + } + } + if len(r.body) == 0 { return nil, false, nil } @@ -47,16 +67,82 @@ func (r *Body) Value(ctx context.Context, name string) (interface{}, bool, error return r.decodeBodyMap(ctx) } if name == "" { - return r.requestState.State(), true, nil + return requestState.State(), true, nil } - sel, err := 
r.requestState.Selector(name) + sel, err := requestState.Selector(name) if err != nil { return nil, false, err } - if !sel.Has(r.requestState.Pointer()) { + if !sel.Has(requestState.Pointer()) { + return nil, false, nil + } + return sel.Value(requestState.Pointer()), true, nil +} + +// initOnce initializes body locator state based on content type (multipart vs non-multipart) +func (r *Body) initOnce() { + r.Once.Do(func() { + // Multipart branch + if r.request != nil { + ct := r.request.Header.Get("Content-Type") + if shared.IsMultipartContentType(ct) { + r.isMultipart = true + if mediaType, _, err := mime.ParseMediaType(ct); err == nil && shared.IsFormData(mediaType) { + r.err = r.request.ParseMultipartForm(maxMultipartMemory) + if r.err == nil { + r.seedFormFromMultipart() + } + } + return + } + } + // Non-multipart: clone and read body safely + var request *http.Request + request, r.err = shared.CloneHTTPRequest(r.request) + r.body, r.err = readRequestBody(request) + }) +} + +// handleMultipartValue returns value for multipart/form-data content +func (r *Body) handleMultipartValue(rType reflect.Type, name string) (interface{}, bool, error) { + if r.err != nil { + return nil, false, r.err + } + if r.request == nil || r.request.MultipartForm == nil { return nil, false, nil } - return sel.Value(r.requestState.Pointer()), true, nil + if name == "" { + return nil, false, nil + } + // File destinations + if rType != nil { + // []*multipart.FileHeader + if rType.Kind() == reflect.Slice && rType.Elem() == reflect.TypeOf((*multipart.FileHeader)(nil)) { + files := r.request.MultipartForm.File[name] + if len(files) == 0 { + return nil, false, nil + } + return files, true, nil + } + // *multipart.FileHeader + if rType == reflect.TypeOf((*multipart.FileHeader)(nil)) { + files := r.request.MultipartForm.File[name] + if len(files) == 0 { + return nil, false, nil + } + return files[0], true, nil + } + } + // Textual parts + if r.request.MultipartForm.Value != nil { + if vs, ok 
:= r.request.MultipartForm.Value[name]; ok && len(vs) > 0 { + if rType != nil && rType.Kind() == reflect.Slice && rType.Elem().Kind() == reflect.String { + return vs, true, nil + } + return vs[0], true, nil + } + } + return nil, false, nil } func (r *Body) decodeBodyMap(ctx context.Context) (interface{}, bool, error) { @@ -74,34 +160,45 @@ func (r *Body) decodeBodyMap(ctx context.Context) (interface{}, bool, error) { // NewBody returns body locator func NewBody(opts ...Option) (kind.Locator, error) { options := NewOptions(opts) - if options.BodyType == nil { - return nil, fmt.Errorf("body type was empty") - } if options.request == nil { return nil, fmt.Errorf("request was empty") } if options.Unmarshal == nil { return nil, fmt.Errorf("unmarshal was empty") } + // Allow missing BodyType only for multipart/* requests; otherwise keep existing requirement. + if options.BodyType == nil { + ct := "" + if options.request != nil && options.request.Header != nil { + ct = options.request.Header.Get("Content-Type") + } + isMultipart := false + if ct != "" { + isMultipart = shared.IsMultipartContentType(ct) + } + if !isMultipart { + return nil, fmt.Errorf("body type was empty") + } + } var ret = &Body{request: options.request, bodyType: options.BodyType, unmarshal: options.UnmarshalFunc(), form: options.Form} return ret, nil } -func (r *Body) ensureRequest() (err error) { - if r.bodyType == nil { - return nil +func (r *Body) ensureRequest(rType reflect.Type) (*structology.State, error) { + if rType == nil { + return nil, nil } - rType := r.bodyType if rType.Kind() == reflect.Map { - return nil + return nil, nil } - bodyType := structology.NewStateType(r.bodyType) - r.requestState = bodyType.NewState() - dest := r.requestState.StatePtr() - if err = r.unmarshal(r.body, dest); err == nil { - r.requestState.Sync() + bodyType := structology.NewStateType(rType) + requestState := bodyType.NewState() + dest := requestState.StatePtr() + err := r.unmarshal(r.body, dest) + if err == nil 
{ + requestState.Sync() } - return err + return requestState, err } func (r *Body) updateQueryString(ctx context.Context, body interface{}) { @@ -136,3 +233,19 @@ func (r *Body) updateQueryString(ctx context.Context, body interface{}) { // Encode the query string and assign it back to the request's URL req.URL.RawQuery = q.Encode() } + +// isMultipartRequest checks content type for multipart/form-data +// removed: local isMultipartRequest; use shared.IsMultipartContentType instead + +// seedFormFromMultipart copies textual multipart values into shared form to avoid re-parsing later +func (r *Body) seedFormFromMultipart() { + if r.request == nil || r.request.MultipartForm == nil || r.form == nil { + return + } + for k, vs := range r.request.MultipartForm.Value { + if len(vs) == 0 { + continue + } + r.form.Set(k, vs...) + } +} diff --git a/view/state/kind/locator/constants.go b/view/state/kind/locator/constants.go index f6cd81eeb..516e7c4a2 100644 --- a/view/state/kind/locator/constants.go +++ b/view/state/kind/locator/constants.go @@ -3,6 +3,7 @@ package locator import ( "context" "github.com/viant/datly/view/state/kind" + "reflect" "sync" ) @@ -16,7 +17,7 @@ func (r *Constants) Names() []string { return nil } -func (r *Constants) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (r *Constants) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { if len(r.constants) > 0 { if value, ok := r.constants[name]; ok { return value, true, nil diff --git a/view/state/kind/locator/context.go b/view/state/kind/locator/context.go index d5fcd827c..33d0985c4 100644 --- a/view/state/kind/locator/context.go +++ b/view/state/kind/locator/context.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/viant/datly/view/state/kind" "github.com/viant/xdatly/handler/exec" + "reflect" ) type Context struct { @@ -14,7 +15,7 @@ func (v *Context) Names() []string { return nil } -func (v *Context) Value(ctx context.Context, name string) 
(interface{}, bool, error) { +func (v *Context) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { rawValue := ctx.Value(exec.ContextKey) if rawValue == nil { diff --git a/view/state/kind/locator/cookie.go b/view/state/kind/locator/cookie.go index 804949592..117550935 100644 --- a/view/state/kind/locator/cookie.go +++ b/view/state/kind/locator/cookie.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/viant/datly/view/state/kind" "net/http" + "reflect" ) type Cookie struct { @@ -19,7 +20,7 @@ func (v *Cookie) Names() []string { return result } -func (v *Cookie) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (v *Cookie) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { for _, cookie := range v.cookies { if cookie.Name == name { return cookie.Value, true, nil diff --git a/view/state/kind/locator/data.go b/view/state/kind/locator/data.go index efcabfdb5..db35876c5 100644 --- a/view/state/kind/locator/data.go +++ b/view/state/kind/locator/data.go @@ -18,7 +18,7 @@ func (p *DataView) Names() []string { return nil } -func (p *DataView) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (p *DataView) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { aView, ok := p.Views[name] if !ok { return nil, false, fmt.Errorf("failed to lookup view: %v", name) diff --git a/view/state/kind/locator/env.go b/view/state/kind/locator/env.go index 05ccc5217..28b980e5d 100644 --- a/view/state/kind/locator/env.go +++ b/view/state/kind/locator/env.go @@ -4,6 +4,7 @@ import ( "context" "github.com/viant/datly/view/state/kind" "os" + "reflect" ) type Env struct { @@ -14,7 +15,7 @@ func (v *Env) Names() []string { return os.Environ() } -func (v *Env) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (v *Env) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { ret, ok := v.env[name] return 
ret, ok, nil } diff --git a/view/state/kind/locator/form.go b/view/state/kind/locator/form.go index 387815e79..2b9190533 100644 --- a/view/state/kind/locator/form.go +++ b/view/state/kind/locator/form.go @@ -2,29 +2,76 @@ package locator import ( "context" + "mime" + "mime/multipart" + "net/http" + "net/url" + "reflect" + "sync" + + "github.com/viant/datly/shared" "github.com/viant/datly/view/state/kind" "github.com/viant/xdatly/handler/state" - "net/http" ) type Form struct { form *state.Form request *http.Request + once sync.Once } func (r *Form) Names() []string { return nil } -func (r *Form) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (r *Form) Value(ctx context.Context, rType reflect.Type, name string) (interface{}, bool, error) { if r.form != nil && len(r.form.Values) == 0 && r.request == nil { return nil, false, nil } + + // Support file uploads when parameters are declared with kind=form + // and types *multipart.FileHeader or []*multipart.FileHeader. This + // aligns multipart file fields with form semantics instead of body. 
+ if r.request != nil && shared.IsMultipartContentType(r.request.Header.Get("Content-Type")) && rType != nil { + // Parse/seed multipart values only once + r.once.Do(func() { r.seedFormFromMultipart() }) + if r.request.MultipartForm != nil { + // []*multipart.FileHeader + if rType.Kind() == reflect.Slice && rType.Elem() == reflect.TypeOf((*multipart.FileHeader)(nil)) { + files := r.request.MultipartForm.File[name] + if len(files) == 0 { + return nil, false, nil + } + return files, true, nil + } + // *multipart.FileHeader + if rType == reflect.TypeOf((*multipart.FileHeader)(nil)) { + files := r.request.MultipartForm.File[name] + if len(files) == 0 { + return nil, false, nil + } + return files[0], true, nil + } + } + } + values, ok := r.form.Lookup(name) if !ok { if r.request == nil { return nil, false, nil } + // If multipart, seed from multipart and avoid FormValue fallback + if shared.IsMultipartContentType(r.request.Header.Get("Content-Type")) { + r.once.Do(func() { r.seedFormFromMultipart() }) + if values, ok = r.form.Lookup(name); ok { + if len(values) > 1 { + return values, true, nil + } + return r.form.Get(name), true, nil + } + return nil, false, nil + } + // Non-multipart: use standard FormValue fallback r.form.Mutex().Lock() defer r.form.Mutex().Unlock() value := r.request.FormValue(name) @@ -49,3 +96,36 @@ func NewForm(opts ...Option) (kind.Locator, error) { var ret = &Form{form: options.Form, request: options.request} return ret, nil } + +// seedFormFromMultipart parses multipart/form-data (if needed) and copies textual values to the shared form +func (r *Form) seedFormFromMultipart() { + if r.request == nil || r.form == nil { + return + } + if r.request.MultipartForm == nil && len(r.form.Values) == 0 { + // Only ParseMultipartForm for form-data; other multipart types aren't + // supported by ParseMultipartForm. If the shared form already has + // values, treat it as authoritative and avoid parsing. 
+ ct := r.request.Header.Get("Content-Type") + if ct != "" { + if mediaType, _, err := mime.ParseMediaType(ct); err == nil && shared.IsFormData(mediaType) { + // Use the same default memory threshold as Body locator + const maxMultipartMemory = 32 << 20 // 32 MiB + _ = r.request.ParseMultipartForm(maxMultipartMemory) + } + } + } + if r.request.MultipartForm == nil { + return + } + if len(r.request.Form) == 0 { + r.request.Form = url.Values{} + } + for k, vs := range r.request.MultipartForm.Value { + if len(vs) == 0 { + continue + } + r.form.Set(k, vs...) + r.request.Form[k] = vs + } +} diff --git a/view/state/kind/locator/generator.go b/view/state/kind/locator/generator.go index 1583357f8..f020b40ff 100644 --- a/view/state/kind/locator/generator.go +++ b/view/state/kind/locator/generator.go @@ -4,6 +4,7 @@ import ( "context" "github.com/google/uuid" "github.com/viant/datly/view/state/kind" + "reflect" "strings" "time" ) @@ -14,7 +15,7 @@ func (v *Generator) Names() []string { return nil } -func (v *Generator) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (v *Generator) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { switch strings.ToLower(name) { case "nil": return nil, true, nil diff --git a/view/state/kind/locator/header.go b/view/state/kind/locator/header.go index e6a8135e2..c2fd3962f 100644 --- a/view/state/kind/locator/header.go +++ b/view/state/kind/locator/header.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/viant/datly/view/state/kind" "net/http" + "reflect" ) type Header struct { @@ -20,7 +21,7 @@ func (q *Header) Names() []string { return result } -func (q *Header) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (q *Header) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { value, ok := q.header[name] if !ok { return nil, false, nil diff --git a/view/state/kind/locator/http.go b/view/state/kind/locator/http.go index 
a2becfe56..62c9ed697 100644 --- a/view/state/kind/locator/http.go +++ b/view/state/kind/locator/http.go @@ -4,10 +4,12 @@ import ( "bytes" "context" "fmt" - "github.com/viant/datly/view/state/kind" "io" "net/http" + "reflect" "strings" + + "github.com/viant/datly/view/state/kind" ) type HttpRequest struct { @@ -19,7 +21,7 @@ func (p *HttpRequest) Names() []string { return nil } -func (p *HttpRequest) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (p *HttpRequest) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { request := p.request if p.request == nil { var err error diff --git a/view/state/kind/locator/object.go b/view/state/kind/locator/object.go index f1c066718..ced3fe835 100644 --- a/view/state/kind/locator/object.go +++ b/view/state/kind/locator/object.go @@ -6,6 +6,7 @@ import ( "github.com/viant/datly/view/state" "github.com/viant/datly/view/state/kind" "github.com/viant/structology" + "reflect" ) type Object struct { @@ -19,7 +20,7 @@ func (p *Object) Names() []string { return nil } -func (p *Object) Value(ctx context.Context, names string) (interface{}, bool, error) { +func (p *Object) Value(ctx context.Context, _ reflect.Type, names string) (interface{}, bool, error) { parameter := p.matchByLocation(names) if parameter == nil { return nil, false, fmt.Errorf("failed to match parameter by location: %v", names) diff --git a/view/state/kind/locator/options.go b/view/state/kind/locator/options.go index d591e8666..18be24842 100644 --- a/view/state/kind/locator/options.go +++ b/view/state/kind/locator/options.go @@ -5,6 +5,7 @@ import ( "net/http" "net/url" "reflect" + "sync" "github.com/viant/datly/gateway/router/marshal/config" "github.com/viant/datly/gateway/router/marshal/json" @@ -21,12 +22,14 @@ import ( // Options represents locator options type ( Options struct { - request *http.Request - Form *hstate.Form - Path map[string]string - Query url.Values - Header http.Header - Body []byte + mu 
sync.RWMutex + request *http.Request + Form *hstate.Form + QuerySelectors hstate.QuerySelectors + Path map[string]string + Query url.Values + Header http.Header + Body []byte fromError error Parent *KindLocator @@ -56,6 +59,9 @@ type ( ) func (o Options) LookupParameters(name string) *state.Parameter { + o.mu.RLock() + defer o.mu.RUnlock() + if len(o.InputParameters) > 0 { if ret, ok := o.InputParameters[name]; ok { return ret @@ -70,10 +76,17 @@ func (o Options) LookupParameters(name string) *state.Parameter { } func (o *Options) GetRequest() (*http.Request, error) { - return shared.CloneHTTPRequest(o.request) + o.mu.RLock() + req := o.request + o.mu.RUnlock() + + return shared.CloneHTTPRequest(req) } func (o *Options) UnmarshalFunc() Unmarshal { + o.mu.Lock() + defer o.mu.Unlock() + if o.Unmarshal != nil { return o.Unmarshal } @@ -100,6 +113,9 @@ var defaultURL, _ = url.Parse("http://localhost:8080/") // WithRequest create http requestState option func WithRequest(request *http.Request) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + ensureValueRequest(request) o.request = request } @@ -117,6 +133,9 @@ func ensureValueRequest(request *http.Request) { // WithCustom creates custom options func WithCustom(options ...interface{}) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Custom = options } } @@ -124,6 +143,9 @@ func WithCustom(options ...interface{}) Option { // WithURIPattern create Path pattern requestState func WithURIPattern(URI string) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.URIPattern = URI } } @@ -131,6 +153,9 @@ func WithURIPattern(URI string) Option { // WithBodyType create Body Type option func WithBodyType(rType reflect.Type) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.BodyType = rType } } @@ -138,6 +163,9 @@ func WithBodyType(rType reflect.Type) Option { // WithUnmarshal creates with unmarshal options func WithUnmarshal(fn 
func([]byte, interface{}) error) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Unmarshal = fn } } @@ -145,6 +173,9 @@ func WithUnmarshal(fn func([]byte, interface{}) error) Option { // WithParent creates with parent options func WithParent(locators *KindLocator) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Parent = locators } } @@ -152,12 +183,18 @@ func WithParent(locators *KindLocator) Option { // WithParameterLookup creates with parameter options func WithParameterLookup(lookupFn ParameterLookup) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.ParameterLookup = lookupFn } } func WithIOConfig(config *config.IOConfig) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.IOConfig = config } } @@ -165,6 +202,9 @@ func WithIOConfig(config *config.IOConfig) Option { // WithInputParameters creates with parameter options func WithInputParameters(parameters state.NamedParameters) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + if len(o.resourceConstants) == 0 { o.resourceConstants = make(map[string]interface{}) } @@ -181,15 +221,30 @@ func WithInputParameters(parameters state.NamedParameters) Option { } } +func WithQuerySelectors(selectors hstate.QuerySelectors) Option { + return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + + o.QuerySelectors = selectors + } +} + // WithPathParameters create with path parameters options func WithPathParameters(parameters map[string]string) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Path = parameters } } func WithReadInto(fn ReadInto) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.ReadInto = fn } } @@ -197,6 +252,9 @@ func WithReadInto(fn ReadInto) Option { // WithViews returns with views options func WithViews(views view.NamedViews) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Views = views } } @@ -204,12 
+262,18 @@ func WithViews(views view.NamedViews) Option { // WithState returns with satte options func WithState(state *structology.State) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.State = state } } func WithOutputParameters(parameters state.Parameters) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.OutputParameters = parameters.Index() } } @@ -217,6 +281,9 @@ func WithOutputParameters(parameters state.Parameters) Option { // WithDispatcher returns options to set dispatcher func WithDispatcher(dispatcher contract.Dispatcher) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Dispatcher = dispatcher } } @@ -224,6 +291,9 @@ func WithDispatcher(dispatcher contract.Dispatcher) Option { // WithView returns options to set view func WithView(aView *view.View) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.View = aView } } @@ -231,6 +301,9 @@ func WithView(aView *view.View) Option { // WithForm return form option func WithForm(form *hstate.Form) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + if o.Form == nil { o.Form = form } else if form != nil { @@ -242,6 +315,9 @@ func WithForm(form *hstate.Form) Option { // WithQuery return query parameters option func WithQuery(parameters url.Values) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + if o.Query == nil { o.Query = parameters } else { @@ -254,6 +330,9 @@ func WithQuery(parameters url.Values) Option { func WithLogger(logger logger.Logger) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Logger = logger } } @@ -261,6 +340,9 @@ func WithLogger(logger logger.Logger) Option { // WithQueryParameter return query parameter option func WithQueryParameter(name, value string) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + if o.Query == nil { o.Query = make(url.Values) } @@ -271,6 +353,9 @@ func WithQueryParameter(name, 
value string) Option { // WithHeader return header option func WithHeader(name, value string) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + if o.Header == nil { o.Header = make(http.Header) } @@ -281,6 +366,9 @@ func WithHeader(name, value string) Option { // WithHeaders return headers option func WithHeaders(header http.Header) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + if o.Header == nil { o.Header = header } @@ -293,6 +381,9 @@ func WithHeaders(header http.Header) Option { // WithMetrics return metrics option func WithMetrics(metrics response.Metrics) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Metrics = metrics } } @@ -300,6 +391,9 @@ func WithMetrics(metrics response.Metrics) Option { // WithResource return resource option func WithResource(resource *view.Resource) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Resource = resource } } @@ -307,6 +401,9 @@ func WithResource(resource *view.Resource) Option { // WithConstants return Constants option func WithConstants(constants map[string]interface{}) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Constants = constants } } @@ -314,6 +411,9 @@ func WithConstants(constants map[string]interface{}) Option { // WithTypes return types option func WithTypes(types ...*state.Type) Option { return func(o *Options) { + o.mu.Lock() + defer o.mu.Unlock() + o.Types = types } } diff --git a/view/state/kind/locator/parameter.go b/view/state/kind/locator/parameter.go index b46f54ed0..35578c13c 100644 --- a/view/state/kind/locator/parameter.go +++ b/view/state/kind/locator/parameter.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/viant/datly/view/state" "github.com/viant/datly/view/state/kind" + "reflect" ) type Parameter struct { @@ -16,7 +17,7 @@ func (p *Parameter) Names() []string { return nil } -func (p *Parameter) Value(ctx context.Context, name string) (interface{}, bool, error) { 
+func (p *Parameter) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { parameter, ok := p.Parameters[name] if !ok { return nil, false, fmt.Errorf("uknonw parameter: %s", name) diff --git a/view/state/kind/locator/path.go b/view/state/kind/locator/path.go index eecabec7c..79a8af548 100644 --- a/view/state/kind/locator/path.go +++ b/view/state/kind/locator/path.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/viant/datly/view/state/kind" "github.com/viant/toolbox" + "reflect" ) type Path struct { @@ -20,7 +21,7 @@ func (v *Path) Names() []string { return result } -func (v *Path) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (v *Path) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { if name == "" { return v.path, true, nil } diff --git a/view/state/kind/locator/query.go b/view/state/kind/locator/query.go index 605f1b1a7..b532215aa 100644 --- a/view/state/kind/locator/query.go +++ b/view/state/kind/locator/query.go @@ -7,6 +7,7 @@ import ( "github.com/viant/xdatly/handler/exec" "net/http" "net/url" + "reflect" ) type Query struct { @@ -23,7 +24,7 @@ func (q *Query) Names() []string { return result } -func (q *Query) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (q *Query) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { if name == "" { return q.rawQuery, true, nil } diff --git a/view/state/kind/locator/repeated.go b/view/state/kind/locator/repeated.go index d40972b4f..2f19e91ae 100644 --- a/view/state/kind/locator/repeated.go +++ b/view/state/kind/locator/repeated.go @@ -3,12 +3,13 @@ package locator import ( "context" "fmt" - "github.com/viant/datly/view/state" - "github.com/viant/datly/view/state/kind" - "github.com/viant/xunsafe" "reflect" "sync" "sync/atomic" + + "github.com/viant/datly/view/state" + "github.com/viant/datly/view/state/kind" + "github.com/viant/xunsafe" ) type Repeated struct { @@ -27,7 
+28,7 @@ func (p *Repeated) Names() []string { return nil } -func (p *Repeated) Value(ctx context.Context, names string) (interface{}, bool, error) { +func (p *Repeated) Value(ctx context.Context, _ reflect.Type, names string) (interface{}, bool, error) { parameter := p.matchByLocation(names) if parameter == nil { return nil, false, fmt.Errorf("failed to match parameter by location: %v", names) diff --git a/view/state/kind/locator/state.go b/view/state/kind/locator/state.go index bbd44cd4c..503dba3f5 100644 --- a/view/state/kind/locator/state.go +++ b/view/state/kind/locator/state.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/viant/datly/view/state/kind" "github.com/viant/structology" + "reflect" ) type State struct { @@ -14,7 +15,7 @@ type State struct { func (p *State) Names() []string { return nil } -func (p *State) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (p *State) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { _, err := p.State.Selector(name) if err != nil { return nil, false, nil diff --git a/view/state/kind/locator/transient.go b/view/state/kind/locator/transient.go index 0fde22e6e..72ac34cca 100644 --- a/view/state/kind/locator/transient.go +++ b/view/state/kind/locator/transient.go @@ -3,6 +3,7 @@ package locator import ( "context" "github.com/viant/datly/view/state/kind" + "reflect" ) type Transient struct{} @@ -11,7 +12,7 @@ func (v *Transient) Names() []string { return nil } -func (v *Transient) Value(ctx context.Context, name string) (interface{}, bool, error) { +func (v *Transient) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { if name == "" { return nil, false, nil } diff --git a/view/state/parameter.go b/view/state/parameter.go index ac6461b57..a1ded9294 100644 --- a/view/state/parameter.go +++ b/view/state/parameter.go @@ -3,6 +3,11 @@ package state import ( "context" "fmt" + "net/http" + "reflect" + "strconv" + "strings" + 
"github.com/viant/datly/internal/setter" "github.com/viant/datly/shared" "github.com/viant/datly/utils/types" @@ -11,10 +16,6 @@ import ( "github.com/viant/structology" "github.com/viant/xreflect" "github.com/viant/xunsafe" - "net/http" - "reflect" - "strconv" - "strings" ) type ( @@ -512,7 +513,12 @@ func (p *Parameter) initCodec(resource Resource) error { if p.Output == nil { return nil } - + stateTag, _ := tags.ParseStateTags(reflect.StructTag(p.Tag), resource.EmbedFS()) + if stateTag != nil { + if stateTag.Codec != nil && stateTag.Codec.Body != "" { + p.Output.Body = stateTag.Codec.Body + } + } inputType := p.Schema.Type() if err := p.Output.Init(resource, inputType); err != nil { return err @@ -520,10 +526,9 @@ func (p *Parameter) initCodec(resource Resource) error { if p.Output.Schema == nil { return nil } - if !p.Output.Schema.IsNamed() { fieldTag := reflect.StructTag(p.Tag) - if stateTag, _ := tags.ParseStateTags(fieldTag, nil); stateTag != nil { + if stateTag != nil { stateTag.TypeName = SanitizeTypeName(p.Output.Schema.Name) p.Tag = string(stateTag.UpdateTag(fieldTag)) } diff --git a/view/state/parameters.go b/view/state/parameters.go index 0b503b84a..e464ce142 100644 --- a/view/state/parameters.go +++ b/view/state/parameters.go @@ -2,6 +2,10 @@ package state import ( "fmt" + "net/http" + "reflect" + "strings" + "github.com/viant/datly/internal/setter" "github.com/viant/datly/shared" "github.com/viant/datly/utils/types" @@ -13,9 +17,6 @@ import ( "github.com/viant/velty" "github.com/viant/xreflect" "github.com/viant/xunsafe" - "net/http" - "reflect" - "strings" ) const ( @@ -219,9 +220,17 @@ func (p Parameters) Groups() []Parameters { } func (p Parameters) SetLiterals(state *structology.State) (err error) { + if state == nil { + return nil + } + stateType := state.Type() for _, parameter := range p.FilterByKind(KindConst) { - if parameter._selector == nil { - parameter._selector = state.Type().Lookup(parameter.Name) + // Selector must be resolved against 
the provided state type. + // Caching it on the parameter is unsafe because the same parameter instance + // can be used with multiple dynamically-generated state types (e.g. during translation). + selector := stateType.Lookup(parameter.Name) + if selector == nil { + return fmt.Errorf("failed to lookup selector for const parameter %q", parameter.Name) } if parameter.Value == nil { switch parameter.Schema.rType.Kind() { @@ -236,7 +245,7 @@ func (p Parameters) SetLiterals(state *structology.State) (err error) { } } - if err = parameter._selector.SetValue(state.Pointer(), parameter.Value); err != nil { + if err = selector.SetValue(state.Pointer(), parameter.Value); err != nil { return err } } @@ -392,7 +401,9 @@ func (p *Parameter) buildField(pkgPath string, lookupType xreflect.LookupType) ( if err != nil { rType, err = types.LookupType(lookupType, schema.DataType, xreflect.WithPackage(pkgPath)) if err != nil { - return structField, markerField, fmt.Errorf("failed to detect parmater '%v' type for: %v %w", p.Name, schema.TypeName(), err) + // Keep unresolved custom parameter types as dynamic `interface{}` so + // scan/planning can continue while preserving declared schema metadata. 
+ rType = reflect.TypeOf((*interface{})(nil)).Elem() } } schema.rType = rType diff --git a/view/state/parameters_set_literals_test.go b/view/state/parameters_set_literals_test.go new file mode 100644 index 000000000..f94d8f5ee --- /dev/null +++ b/view/state/parameters_set_literals_test.go @@ -0,0 +1,64 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/viant/structology" +) + +func TestParameters_SetLiterals_DoesNotReuseSelectorAcrossStateTypes(t *testing.T) { + const ( + paramName = "X" + dummyName = "Dummy" + dummyValue = 12345 + constValue = true + constSource = "value" + ) + + param := &Parameter{ + Name: paramName, + In: &Location{Kind: KindConst, Name: constSource}, + Value: constValue, + Schema: &Schema{ + rType: reflect.TypeOf(true), + }, + } + params := Parameters{param} + + type1 := reflect.StructOf([]reflect.StructField{ + {Name: dummyName, Type: reflect.TypeOf(int(0))}, + {Name: paramName, Type: reflect.TypeOf(true)}, + }) + state1 := structology.NewStateType(type1).NewState() + if err := state1.SetInt(dummyName, dummyValue); err != nil { + t.Fatalf("failed to init %s: %v", dummyName, err) + } + if err := params.SetLiterals(state1); err != nil { + t.Fatalf("SetLiterals(type1) failed: %v", err) + } + if got, err := state1.Bool(paramName); err != nil || got != constValue { + t.Fatalf("type1 %s: got=%v err=%v, want=%v", paramName, got, err, constValue) + } + if got, err := state1.Value(dummyName); err != nil || got.(int) != dummyValue { + t.Fatalf("type1 %s: got=%v err=%v, want=%v", dummyName, got, err, dummyValue) + } + + type2 := reflect.StructOf([]reflect.StructField{ + {Name: paramName, Type: reflect.TypeOf(true)}, + {Name: dummyName, Type: reflect.TypeOf(int(0))}, + }) + state2 := structology.NewStateType(type2).NewState() + if err := state2.SetInt(dummyName, dummyValue); err != nil { + t.Fatalf("failed to init %s: %v", dummyName, err) + } + if err := params.SetLiterals(state2); err != nil { + t.Fatalf("SetLiterals(type2) failed: 
%v", err) + } + if got, err := state2.Bool(paramName); err != nil || got != constValue { + t.Fatalf("type2 %s: got=%v err=%v, want=%v", paramName, got, err, constValue) + } + if got, err := state2.Value(dummyName); err != nil || got.(int) != dummyValue { + t.Fatalf("type2 %s: got=%v err=%v, want=%v", dummyName, got, err, dummyValue) + } +} diff --git a/view/state/resource.go b/view/state/resource.go index 7c39ad857..b80f1fe1d 100644 --- a/view/state/resource.go +++ b/view/state/resource.go @@ -2,6 +2,7 @@ package state import ( "context" + "embed" "github.com/viant/xdatly/codec" "github.com/viant/xreflect" ) @@ -22,5 +23,9 @@ type ( ExpandSubstitutes(text string) string ReverseSubstitutes(text string) string + + EmbedFS() *embed.FS + + SetFSEmbedder(embedder *FSEmbedder) } ) diff --git a/view/state/type.go b/view/state/type.go index d64c07281..262a83959 100644 --- a/view/state/type.go +++ b/view/state/type.go @@ -4,6 +4,10 @@ import ( "context" "embed" "fmt" + "reflect" + "strings" + "unicode" + "github.com/viant/datly/internal/setter" "github.com/viant/datly/utils/types" "github.com/viant/datly/view/extension" @@ -11,9 +15,6 @@ import ( "github.com/viant/structology" "github.com/viant/tagly/format/text" "github.com/viant/xreflect" - "reflect" - "strings" - "unicode" ) type ( @@ -110,6 +111,10 @@ func (t *Type) ensureEmbedder(reflect.Type) { t.embedder = NewFSEmbedder(nil) } t.embedder.SetType(reflect.TypeOf(t)) + if t.resource != nil && t.resource.EmbedFS() == nil { + t.resource.SetFSEmbedder(t.embedder) + } + } func (t *Type) adjustConstants() { @@ -261,11 +266,16 @@ func BuildSchema(field *reflect.StructField, pTag *tags.Parameter, result *Param isSlice = true rawType = rawType.Elem() } + isPtr := false if rawType.Kind() == reflect.Ptr { rawType = rawType.Elem() + isPtr = true } rawName := rawType.Name() + if isPtr { + rawName = "*" + rawName + } if pTag.Cardinality != "" { result.ensureSchema() result.Schema.Cardinality = Cardinality(pTag.Cardinality) diff 
--git a/view/state/types.go b/view/state/types.go index 292ad4147..f301290ea 100644 --- a/view/state/types.go +++ b/view/state/types.go @@ -11,6 +11,9 @@ type Types struct { } func (c *Types) Lookup(p reflect.Type) (*Type, bool) { + if len(c.types) == 0 { + return nil, false + } c.RWMutex.RLock() ret, ok := c.types[p] c.RWMutex.RUnlock() diff --git a/view/tags/codec.go b/view/tags/codec.go index a16be7509..d606ed717 100644 --- a/view/tags/codec.go +++ b/view/tags/codec.go @@ -1,9 +1,9 @@ package tags import ( - "fmt" - "github.com/viant/tagly/tags" "strings" + + "github.com/viant/tagly/tags" ) // CodecTag codec tag @@ -31,10 +31,11 @@ func (t *Tag) updatedCodec(key string, value string) (err error) { } tag.Body = string(data) default: + expr := key if value != "" { - return fmt.Errorf("invalid argument %s", value) + expr += " =" + value } - tag.Arguments = append(tag.Arguments, key) + tag.Arguments = append(tag.Arguments, expr) } return err } diff --git a/view/tags/parameter.go b/view/tags/parameter.go index 7acd1c563..de777c815 100644 --- a/view/tags/parameter.go +++ b/view/tags/parameter.go @@ -88,7 +88,7 @@ func (p *Parameter) Tag() *tags.Tag { if *p.Cacheable { value = "true" } - appendNonEmpty(builder, "cachable", value) + appendNonEmpty(builder, "cacheable", value) } if p.Cardinality == "One" { diff --git a/view/tags/parameter_test.go b/view/tags/parameter_test.go index 6cb562220..aa27bb52e 100644 --- a/view/tags/parameter_test.go +++ b/view/tags/parameter_test.go @@ -24,7 +24,7 @@ func TestTag_updateParameter(t *testing.T) { { description: "async Parameter", tag: `parameter:"p1,kind=query,in=qp1,scope=async"`, - expect: &Parameter{Name: "p1", Kind: "query", In: "qp1", Scope: "myscope"}, + expect: &Parameter{Name: "p1", Kind: "query", In: "qp1", Scope: "async"}, }, } diff --git a/view/tags/parser.go b/view/tags/parser.go index aa4c407a0..283d1eded 100644 --- a/view/tags/parser.go +++ b/view/tags/parser.go @@ -4,14 +4,15 @@ import ( "context" "embed" "fmt" + 
"reflect" + "strings" + "github.com/viant/afs" "github.com/viant/afs/storage" "github.com/viant/afs/url" "github.com/viant/tagly/format" "github.com/viant/tagly/tags" "github.com/viant/xreflect" - "reflect" - "strings" ) // ValueTag represents default value tag diff --git a/view/tags/predicate.go b/view/tags/predicate.go index e66ef5e76..956a2bc5b 100644 --- a/view/tags/predicate.go +++ b/view/tags/predicate.go @@ -2,9 +2,10 @@ package tags import ( "fmt" - "github.com/viant/tagly/tags" "strconv" "strings" + + "github.com/viant/tagly/tags" ) // PredicateTag Predicate tag @@ -70,10 +71,11 @@ func (t *Tag) updatedPredicate(key string, value string) (err error) { return fmt.Errorf("invalid predicate ensure: %s %w", value, err) } default: + expr := key if value != "" { - return fmt.Errorf("invalid argument %s", value) + expr = key + "=" + value } - tag.Arguments = append(tag.Arguments, key) + tag.Arguments = append(tag.Arguments, expr) } return err } diff --git a/view/tags/view_test.go b/view/tags/view_test.go index e85c3f2d3..1127cf66c 100644 --- a/view/tags/view_test.go +++ b/view/tags/view_test.go @@ -29,7 +29,7 @@ func TestTag_updateView(t *testing.T) { description: "basic view", tag: `view:"foo,connector=dev" sql:"uri=testdata/foo.sql"`, expectView: &View{Name: "foo", Connector: "dev"}, - expectSQL: ViewSQL{SQL: "SELECT * FROM FOO"}, + expectSQL: ViewSQL{SQL: "SELECT * FROM FOO", URI: "testdata/foo.sql"}, expectTag: "foo,connector=dev", }, { diff --git a/view/template.go b/view/template.go index c0b523b6c..ae4308723 100644 --- a/view/template.go +++ b/view/template.go @@ -231,6 +231,10 @@ func (t *Template) EvaluateState(ctx context.Context, parameterState *structolog } func (t *Template) EvaluateStateWithSession(ctx context.Context, parameterState *structology.State, parentParam *expand.ViewContext, batchData *BatchData, sess *extension.Session, options ...interface{}) (*expand.State, error) { + // Ensure parameter state is initialized when absent, but never 
override an existing one. + if parameterState == nil && t.stateType != nil { + parameterState = t.stateType.NewState() + } var expander expand.Expander var dataUnit *expand.DataUnit for _, option := range options { @@ -372,8 +376,7 @@ func (t *Template) Expand(placeholders *[]interface{}, SQL string, selector *Sta if value.Key == "?" { placeholder, err := sanitized.Next() if err != nil { - return "", fmt.Errorf("failed to get placeholder: %w, SQL: %v, values: %v\n", err, SQL, values) - + return "", fmt.Errorf("failed to get placeholder: %w, SQL: %v, values: %+v\n", err, SQL, values) } *placeholders = append(*placeholders, placeholder) continue diff --git a/view/view.go b/view/view.go index 42bcc1906..b297489f1 100644 --- a/view/view.go +++ b/view/view.go @@ -4,6 +4,12 @@ import ( "context" "database/sql" "fmt" + "net/http" + "path" + "reflect" + "strings" + "time" + "github.com/viant/afs/url" "github.com/viant/datly/gateway/router/marshal" "github.com/viant/datly/internal/setter" @@ -23,11 +29,6 @@ import ( "github.com/viant/tagly/format/text" "github.com/viant/xreflect" "github.com/viant/xunsafe" - "net/http" - "path" - "reflect" - "strings" - "time" ) const ( @@ -154,15 +155,16 @@ func (v *View) Context(ctx context.Context) context.Context { // Constraints configure what can be selected by Statelet // For each _field, default value is `false` type Constraints struct { - Criteria bool - OrderBy bool - Limit bool - Offset bool - Projection bool //enables columns projection from client (default ${NS}_fields= query param) - Filterable []string - SQLMethods []*Method `json:",omitempty"` - _sqlMethods map[string]*Method - Page *bool + Criteria bool + OrderBy bool + OrderByColumn map[string]string + Limit bool + Offset bool + Projection bool //enables columns projection from client (default ${NS}_fields= query param) + Filterable []string + SQLMethods []*Method `json:",omitempty"` + _sqlMethods map[string]*Method + Page *bool } func (v *View) Resource() state.Resource { 
@@ -400,6 +402,10 @@ func (v *View) inheritRelationsFromTag(schema *state.Schema) error { refViewOptions = append(refViewOptions, WithCache(aCache)) } + if viewTag.Limit != nil { + viewOptions = append(viewOptions, WithLimit(viewTag.Limit)) + } + if viewTag.PublishParent { refViewOptions = append(refViewOptions, WithViewPublishParent(viewTag.PublishParent)) } @@ -473,6 +479,9 @@ func WithLimit(limit *int) Option { } view.Selector.Constraints.Limit = true view.Selector.Limit = *limit + if limit != nil { + view.Selector.NoLimit = *limit == 0 + } return nil } } @@ -1289,7 +1298,7 @@ func (v *View) markColumnsAsFilterable() error { for _, colName := range v.Selector.Constraints.Filterable { column, err := v._columns.Lookup(colName) if err != nil { - return fmt.Errorf("criteria column %v, on view has not been defined, %w", colName, v.Name, err) + return fmt.Errorf("criteria column %v on view %v has not been defined: %w", colName, v.Name, err) } column.Filterable = true } diff --git a/warmup/cache_test.go b/warmup/cache_test.go index e691c8357..9196529d6 100644 --- a/warmup/cache_test.go +++ b/warmup/cache_test.go @@ -2,17 +2,21 @@ package warmup import ( "context" + "os" + "path" + "testing" + "github.com/stretchr/testify/assert" - "github.com/viant/afs" - "github.com/viant/datly/gateway/router" "github.com/viant/datly/internal/tests" "github.com/viant/datly/service/reader" "github.com/viant/datly/view" - "path" - "testing" ) func TestPopulateCache(t *testing.T) { + if os.Getenv("DATLY_RUN_WARMUP_TESTS") == "" { + t.Skip("set DATLY_RUN_WARMUP_TESTS=1 to run warmup integration test") + } + testCases := []struct { description string URL string @@ -59,14 +63,14 @@ func TestPopulateCache(t *testing.T) { resourcePath := path.Join("testdata", testCase.URL, "resource.yaml") - resource, err := router.NewResourceFromURL(context.TODO(), afs.New(), resourcePath, false) + resource, err := view.NewResourceFromURL(context.TODO(), resourcePath, nil, nil) if !assert.Nil(t, err, 
testCase.description) { continue } var views []*view.View - for _, route := range resource.Routes { - views = append(views, route.View) + for _, item := range resource.Views { + views = append(views, item) } inserted, err := PopulateCache(views) @@ -100,7 +104,7 @@ func checkIfCached(t *testing.T, cache *view.Cache, ctx context.Context, testCas builder := reader.NewBuilder() for _, cacheInput := range input { - build, err := builder.CacheSQL(aView, cacheInput.Selector) + build, err := builder.CacheSQL(ctx, aView, cacheInput.Selector) if err != nil { return err } @@ -116,7 +120,7 @@ func checkIfCached(t *testing.T, cache *view.Cache, ctx context.Context, testCas } if cacheInput.IndexMeta && aView.Template.Summary != nil { - metaIndex, err := builder.CacheMetaSQL(aView, cacheInput.Selector, &view.BatchData{ + metaIndex, err := builder.CacheMetaSQL(ctx, aView, cacheInput.Selector, &view.BatchData{ ValuesBatch: testCase.metaIndexed, Values: testCase.metaIndexed, }, nil, nil)