diff --git a/pkg/datastore/defaults/defaults.go b/pkg/datastore/defaults/defaults.go index ebc33115..ecdc9624 100644 --- a/pkg/datastore/defaults/defaults.go +++ b/pkg/datastore/defaults/defaults.go @@ -4,7 +4,7 @@ const ( TopK int = 10 TextSplitterTokenModel = "gpt-4" - TextSplitterChunkSize = 1024 + TextSplitterChunkSize = 2048 TextSplitterChunkOverlap = 256 TextSplitterTokenEncoding = "cl100k_base" ) diff --git a/pkg/datastore/textsplitter/markdown_rolling/markdown_rolling.go b/pkg/datastore/textsplitter/markdown_rolling/markdown_rolling.go new file mode 100644 index 00000000..5ca717e4 --- /dev/null +++ b/pkg/datastore/textsplitter/markdown_rolling/markdown_rolling.go @@ -0,0 +1,187 @@ +package markdown_rolling + +import ( + "fmt" + "strings" + + "github.com/pkoukk/tiktoken-go" + lcgosplitter "github.com/tmc/langchaingo/textsplitter" +) + +// NewMarkdownTextSplitter creates a new Markdown text splitter. +func NewMarkdownTextSplitter(opts ...Option) (*MarkdownTextSplitter, error) { + options := DefaultOptions() + + for _, opt := range opts { + opt(&options) + } + + var tk *tiktoken.Tiktoken + var err error + if options.EncodingName != "" { + tk, err = tiktoken.GetEncoding(options.EncodingName) + } else { + tk, err = tiktoken.EncodingForModel(options.ModelName) + } + if err != nil { + return nil, fmt.Errorf("couldn't get encoding: %w", err) + } + + tokenSplitter := lcgosplitter.TokenSplitter{ + ChunkSize: options.ChunkSize, + ChunkOverlap: options.ChunkOverlap, + ModelName: options.ModelName, + EncodingName: options.EncodingName, + AllowedSpecial: []string{}, + DisallowedSpecial: []string{"all"}, + } + + return &MarkdownTextSplitter{ + options, + tk, + tokenSplitter, + }, nil +} + +// MarkdownTextSplitter markdown header text splitter. +type MarkdownTextSplitter struct { + Options + *tiktoken.Tiktoken + tokenSplitter lcgosplitter.TokenSplitter +} + +type block struct { + headings []string + lines []string + text string + tokenSize int +} + +func (s *MarkdownTextSplitter) getTokenSize(text string) int { + return len(s.Encode(text, []string{}, []string{"all"})) +} + +func (s *MarkdownTextSplitter) finishBlock(blocks []block, currentBlock block, headingStack []string) ([]block, block, error) { + + for _, header := range headingStack { + if header != "" { + currentBlock.headings = append(currentBlock.headings, header) + } + } + + if len(currentBlock.lines) == 0 && s.IgnoreHeadingOnly { + return blocks, block{}, nil + } + + headingStr := strings.TrimSpace(strings.Join(currentBlock.headings, "\n")) + contentStr := strings.TrimSpace(strings.Join(currentBlock.lines, "\n")) + text := headingStr + "\n" + contentStr + + if len(text) == 0 { + return blocks, block{}, nil + } + + textTokenSize := s.getTokenSize(text) + + if textTokenSize <= s.ChunkSize { + // append new block to free up some space + return append(blocks, block{ + text: text, + tokenSize: textTokenSize, + }), block{}, nil + } + + // If the block is larger than the chunk size, split it + headingTokenSize := s.getTokenSize(headingStr) + + // Split into chunks that leave room for the heading + s.tokenSplitter.ChunkSize = s.ChunkSize - headingTokenSize + + splits, err := s.tokenSplitter.SplitText(contentStr) + if err != nil { + return blocks, block{}, err + } + + for _, split := range splits { + text = headingStr + "\n" + split + blocks = append(blocks, block{ + text: text, + tokenSize: s.getTokenSize(text), + }) + } + + return blocks, block{}, nil + +} + +// SplitText splits text into chunks. +func (s *MarkdownTextSplitter) SplitText(text string) ([]string, error) { + + var ( + headingStack []string + chunks []string + currentChunk block + currentHeadingLevel int = 1 + currentBlock block + + blocks []block + err error + ) + + // Parse markdown line-by-line and build heading-delimited blocks + for _, line := range strings.Split(text, "\n") { + + // Handle header = start a new block + if strings.HasPrefix(line, "#") { + // Finish the previous Block + blocks, currentBlock, err = s.finishBlock(blocks, currentBlock, headingStack) + if err != nil { + return nil, err + } + + // Get the header level + headingLevel := strings.Count(strings.Split(line, " ")[0], "#") - 1 + + headingStack = append(headingStack[:headingLevel], line) + + // Clear the header stack for lower level headers + for j := headingLevel + 1; j < len(headingStack); j++ { + headingStack[j] = "" + } + + // Reset header stack indices between this level and the last seen level, backwards + for j := headingLevel - 1; j > currentHeadingLevel; j-- { + headingStack[j] = "" + } + + currentHeadingLevel = headingLevel + continue + + } + + // If the line is not a header, add it to the current block + currentBlock.lines = append(currentBlock.lines, line) + + } + + // Finish the last block + blocks, currentBlock, err = s.finishBlock(blocks, currentBlock, headingStack) + if err != nil { + return nil, err + } + + // Combine blocks into chunks as close to the target token size as possible + for _, b := range blocks { + if currentChunk.tokenSize+b.tokenSize <= s.ChunkSize { + // Doesn't exceed chunk size, so add to the current chunk + currentChunk.text += "\n" + b.text + currentChunk.tokenSize += b.tokenSize + } else { + // Exceeds chunk size, so start a new chunk + chunks = append(chunks, currentChunk.text) + currentChunk = b + } + } + + return chunks, nil +} diff --git a/pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go b/pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go new file mode 100644 index 00000000..4137cb7c --- /dev/null +++ b/pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go @@ -0,0 +1,85 @@ +package markdown_rolling + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSplitTextWithBasicMarkdown(t *testing.T) { + splitter := NewMarkdownTextSplitter() + chunks, err := splitter.SplitText("# Heading\n\nThis is a paragraph.") + assert.NoError(t, err) + assert.Equal(t, 1, len(chunks)) + + expected := []string{"# Heading\nThis is a paragraph."} + + assert.Equal(t, expected, chunks) +} + +func TestSplitTextWithOptions(t *testing.T) { + md := ` +# Heading 1 + +some p under h1 + +## Heading 2 +### Heading 3 + +- some +- list +- items + +**bold** + +# 2nd Heading 1 +#### Heading 4 + +some p under h4 +` + + testcases := []struct { + name string + splitter *MarkdownTextSplitter + expected []string + }{ + { + name: "default", + splitter: NewMarkdownTextSplitter(), + expected: []string{ + "# Heading 1\nsome p under h1", + "# Heading 1\n## Heading 2", + "# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**", + "# 2nd Heading 1", + "# 2nd Heading 1\n#### Heading 4\nsome p under h4", + }, + }, + { + name: "ignore_heading_only", + splitter: NewMarkdownTextSplitter(WithIgnoreHeadingOnly(true)), + expected: []string{ + "# Heading 1\nsome p under h1", + "# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**", + "# 2nd Heading 1\n#### Heading 4\nsome p under h4", + }, + }, + { + name: "split_h1_only", + splitter: NewMarkdownTextSplitter(), + expected: []string{ + "# Heading 1\nsome p under h1\n\n## Heading 2\n### Heading 3\n\n- some\n- list\n- items\n\n**bold**", + "# 2nd Heading 1\n#### Heading 4\n\nsome p under h4", + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + chunks, err := tc.splitter.SplitText(md) + assert.NoError(t, err) + assert.Equal(t, len(tc.expected), len(chunks)) + + assert.Equal(t, tc.expected, chunks) + }) + } +} diff --git a/pkg/datastore/textsplitter/markdown_rolling/options.go b/pkg/datastore/textsplitter/markdown_rolling/options.go new file mode 100644 index 00000000..8887919d --- /dev/null +++ b/pkg/datastore/textsplitter/markdown_rolling/options.go @@ -0,0 +1,69 @@ +package markdown_rolling + +import ( + "github.com/gptscript-ai/knowledge/pkg/datastore/defaults" + lcgosplitter "github.com/tmc/langchaingo/textsplitter" +) + +// Options is a struct that contains options for a text splitter. +type Options struct { + ChunkSize int + ChunkOverlap int + Separators []string + KeepSeparator bool + ModelName string + EncodingName string + SecondSplitter lcgosplitter.TextSplitter + + IgnoreHeadingOnly bool // Ignore chunks that only contain headings +} + +// DefaultOptions returns the default options for all text splitter. +func DefaultOptions() Options { + return Options{ + ChunkSize: defaults.TextSplitterChunkSize, + ChunkOverlap: defaults.TextSplitterChunkOverlap, + + ModelName: defaults.TextSplitterTokenModel, + EncodingName: defaults.TextSplitterTokenEncoding, + + IgnoreHeadingOnly: true, + } +} + +// Option is a function that can be used to set options for a text splitter. +type Option func(*Options) + +// WithChunkSize sets the chunk size for a text splitter. +func WithChunkSize(chunkSize int) Option { + return func(o *Options) { + o.ChunkSize = chunkSize + } +} + +// WithChunkOverlap sets the chunk overlap for a text splitter. +func WithChunkOverlap(chunkOverlap int) Option { + return func(o *Options) { + o.ChunkOverlap = chunkOverlap + } +} + +// WithModelName sets the model name for a text splitter. +func WithModelName(modelName string) Option { + return func(o *Options) { + o.ModelName = modelName + } +} + +// WithEncodingName sets the encoding name for a text splitter. +func WithEncodingName(encodingName string) Option { + return func(o *Options) { + o.EncodingName = encodingName + } +} + +func WithIgnoreHeadingOnly(ignoreHeadingOnly bool) Option { + return func(o *Options) { + o.IgnoreHeadingOnly = ignoreHeadingOnly + } +}