Skip to content

Commit

Permalink
documentloaders: add AssemblyAI document loader (#668)
Browse files Browse the repository at this point in the history
* documentloaders: added assemblyai document loader
  • Loading branch information
marcusolsson authored Mar 18, 2024
1 parent ebb5d1a commit 0218733
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ linters:
- ireturn # disabling temporarily

linters-settings:
cyclop:
max-complexity: 12
funlen:
lines: 90
depguard:
Expand Down
248 changes: 248 additions & 0 deletions documentloaders/assemblyai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
package documentloaders

import (
"context"
"encoding/json"
"errors"
"io"

"github.com/AssemblyAI/assemblyai-go-sdk"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/textsplitter"
)

// ErrMissingAudioSource is returned when neither an audio URL nor a reader has
// been set using [WithAudioURL] or [WithAudioReader]. The loader returns it
// unwrapped, so callers can detect it with errors.Is.
var ErrMissingAudioSource = errors.New("assemblyai: missing audio source")

// TranscriptFormat represents the format of the document page content, that
// is, how a transcript is turned into one or more documents. A format is
// selected with [WithTranscriptFormat].
type TranscriptFormat int

const (
	// TranscriptFormatText produces a single document with the full
	// transcript text as page content. It is the zero value of
	// TranscriptFormat.
	TranscriptFormatText TranscriptFormat = iota

	// TranscriptFormatSentences produces multiple documents with each
	// sentence as page content.
	TranscriptFormatSentences

	// TranscriptFormatParagraphs produces multiple documents with each
	// paragraph as page content.
	TranscriptFormatParagraphs

	// TranscriptFormatSubtitlesSRT produces a single document with SRT
	// formatted subtitles as page content.
	TranscriptFormatSubtitlesSRT

	// TranscriptFormatSubtitlesVTT produces a single document with VTT
	// formatted subtitles as page content.
	TranscriptFormatSubtitlesVTT
)

// AssemblyAIAudioTranscriptLoader transcribes an audio file using AssemblyAI
// and loads the transcript.
//
// Audio files can be specified using either a URL or a reader. If both are
// set, the URL is used and the reader is ignored.
//
// For a list of the supported audio and video formats, see the [FAQ].
//
// [FAQ]: https://www.assemblyai.com/docs/concepts/faq
type AssemblyAIAudioTranscriptLoader struct {
	// AssemblyAI API client used for the requests made by the loader.
	client *assemblyai.Client

	// URL of the audio file to transcribe. Empty when no URL has been set.
	url string

	// Reader of the audio file to transcribe. Nil when no reader has been set.
	r io.Reader

	// Optional parameters for the transcription. May be nil; the value is
	// handed to the client unchanged.
	params *assemblyai.TranscriptOptionalParams

	// Format of the document page content.
	format TranscriptFormat
}

// Compile-time assertion that *AssemblyAIAudioTranscriptLoader satisfies Loader.
var _ Loader = (*AssemblyAIAudioTranscriptLoader)(nil)

// AssemblyAIOption is an option for the AssemblyAI loader. An option is a
// function that mutates the loader it is given; options are passed to
// [NewAssemblyAIAudioTranscript], which applies them in argument order.
type AssemblyAIOption func(loader *AssemblyAIAudioTranscriptLoader)

// NewAssemblyAIAudioTranscript returns a new AssemblyAIAudioTranscriptLoader
// whose AssemblyAI client is created from apiKey.
//
// The loader starts with the TranscriptFormatText format and no audio source.
// The given options are then applied in order, so a later option overrides an
// earlier one that sets the same field.
func NewAssemblyAIAudioTranscript(apiKey string, opts ...AssemblyAIOption) *AssemblyAIAudioTranscriptLoader {
	l := new(AssemblyAIAudioTranscriptLoader)
	l.client = assemblyai.NewClient(apiKey)
	l.format = TranscriptFormatText

	for i := range opts {
		opts[i](l)
	}

	return l
}

// WithAudioURL returns an option that makes the loader transcribe the audio
// file found at url. The URL needs to be accessible from AssemblyAI's servers.
func WithAudioURL(url string) AssemblyAIOption {
	return func(loader *AssemblyAIAudioTranscriptLoader) {
		loader.url = url
	}
}

// WithAudioReader returns an option that makes the loader transcribe audio
// data read from r, for example a local audio file.
func WithAudioReader(r io.Reader) AssemblyAIOption {
	return func(loader *AssemblyAIAudioTranscriptLoader) {
		loader.r = r
	}
}

// WithTranscriptFormat configures the format of the document page content.
// The format determines how the transcript is turned into documents, for
// example a single document with the full text or one document per sentence.
func WithTranscriptFormat(format TranscriptFormat) AssemblyAIOption {
	return func(a *AssemblyAIAudioTranscriptLoader) {
		a.format = format
	}
}

// WithTranscriptParams returns an option that sets the optional parameters
// sent along with the transcription request. The pointer is stored as-is, not
// copied.
func WithTranscriptParams(params *assemblyai.TranscriptOptionalParams) AssemblyAIOption {
	return func(loader *AssemblyAIAudioTranscriptLoader) {
		loader.params = params
	}
}

// Load transcribes the configured audio source using AssemblyAI and returns
// the transcript as one or more documents, depending on the configured
// transcript format.
//
// It returns ErrMissingAudioSource when neither an audio URL nor a reader has
// been configured. Any other error comes from the transcription or from
// converting the transcript into documents.
func (a *AssemblyAIAudioTranscriptLoader) Load(ctx context.Context) ([]schema.Document, error) {
	transcript, err := a.transcribe(ctx)
	if err != nil {
		return nil, err
	}

	// The formatter already reports failures as (nil, err), so its result can
	// be handed back unchanged.
	return a.formatTranscript(ctx, transcript)
}

// transcribe submits the configured audio source for transcription. A
// non-empty URL takes precedence over a reader; if neither is set it returns
// a zero Transcript and ErrMissingAudioSource.
func (a *AssemblyAIAudioTranscriptLoader) transcribe(ctx context.Context) (assemblyai.Transcript, error) {
	switch {
	case a.url != "":
		return a.client.Transcripts.TranscribeFromURL(ctx, a.url, a.params)
	case a.r != nil:
		return a.client.Transcripts.TranscribeFromReader(ctx, a.r, a.params)
	default:
		return assemblyai.Transcript{}, ErrMissingAudioSource
	}
}

// formatTranscript converts a transcript into documents according to the
// loader's configured format.
//
// The sentence, paragraph and subtitle formats fetch their representation
// from AssemblyAI using the transcript ID. The text format, and any
// unrecognized format value, yields a single document holding the transcript
// text, with the remaining transcript fields as metadata.
func (a *AssemblyAIAudioTranscriptLoader) formatTranscript(ctx context.Context, transcript assemblyai.Transcript) ([]schema.Document, error) {
	id := assemblyai.ToString(transcript.ID)

	switch a.format {
	case TranscriptFormatSentences:
		resp, err := a.client.Transcripts.GetSentences(ctx, id)
		if err != nil {
			return nil, err
		}
		return documentsFromSentences(resp.Sentences)

	case TranscriptFormatParagraphs:
		resp, err := a.client.Transcripts.GetParagraphs(ctx, id)
		if err != nil {
			return nil, err
		}
		return documentsFromParagraphs(resp.Paragraphs)

	case TranscriptFormatSubtitlesSRT:
		subtitles, err := a.client.Transcripts.GetSubtitles(ctx, id, "srt", nil)
		if err != nil {
			return nil, err
		}
		return []schema.Document{{PageContent: string(subtitles)}}, nil

	case TranscriptFormatSubtitlesVTT:
		subtitles, err := a.client.Transcripts.GetSubtitles(ctx, id, "vtt", nil)
		if err != nil {
			return nil, err
		}
		return []schema.Document{{PageContent: string(subtitles)}}, nil

	case TranscriptFormatText:
		// Handled below, together with unrecognized format values.
	}

	metadata, err := toMetadata(transcript)
	if err != nil {
		return nil, err
	}

	doc := schema.Document{
		PageContent: assemblyai.ToString(transcript.Text),
		Metadata:    metadata,
	}
	return []schema.Document{doc}, nil
}

// documentsFromSentences builds one document per sentence, preserving order.
// The sentence text becomes the page content and the sentence's other fields
// become the metadata. It returns nil and the error as soon as the metadata
// of any sentence cannot be built.
func documentsFromSentences(sentences []assemblyai.TranscriptSentence) ([]schema.Document, error) {
	docs := make([]schema.Document, len(sentences))

	for i := range sentences {
		metadata, err := toMetadata(sentences[i])
		if err != nil {
			return nil, err
		}

		docs[i].PageContent = assemblyai.ToString(sentences[i].Text)
		docs[i].Metadata = metadata
	}

	return docs, nil
}

// documentsFromParagraphs builds one document per paragraph, preserving
// order. The paragraph text becomes the page content and the paragraph's other
// fields become the metadata. It returns nil and the error as soon as the
// metadata of any paragraph cannot be built.
func documentsFromParagraphs(paragraphs []assemblyai.TranscriptParagraph) ([]schema.Document, error) {
	docs := make([]schema.Document, len(paragraphs))

	for i := range paragraphs {
		metadata, err := toMetadata(paragraphs[i])
		if err != nil {
			return nil, err
		}

		docs[i].PageContent = assemblyai.ToString(paragraphs[i].Text)
		docs[i].Metadata = metadata
	}

	return docs, nil
}

// toMetadata turns obj into a generic map by round-tripping it through JSON,
// so the map keys are the JSON field names of obj. The "text" key is dropped
// because the text is already carried as the document's page content.
//
// It returns an error if obj cannot be marshaled or if its JSON form is not
// an object.
func toMetadata(obj any) (map[string]any, error) {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return nil, marshalErr
	}

	var fields map[string]any
	unmarshalErr := json.Unmarshal(encoded, &fields)
	if unmarshalErr != nil {
		return nil, unmarshalErr
	}

	// Remove redundant transcript text.
	delete(fields, "text")

	return fields, nil
}

// LoadAndSplit loads the transcript documents with Load and then divides them
// into smaller documents using the given text splitter. An error from Load is
// returned as-is, without any splitting being attempted.
func (a *AssemblyAIAudioTranscriptLoader) LoadAndSplit(ctx context.Context, splitter textsplitter.TextSplitter) ([]schema.Document, error) {
	loaded, err := a.Load(ctx)
	if err != nil {
		return nil, err
	}

	return textsplitter.SplitDocuments(splitter, loaded)
}
58 changes: 58 additions & 0 deletions documentloaders/assemblyai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package documentloaders

import (
"context"
"os"
"testing"

aai "github.com/AssemblyAI/assemblyai-go-sdk"
"github.com/stretchr/testify/require"
)

// TestAssemblyAIAudioTranscriptLoader_Load loads a transcript for a sample
// audio URL in the text format and checks that a single, non-empty document
// comes back whose metadata reflects the requested PII redaction.
//
// The test is skipped unless ASSEMBLYAI_API_KEY is set.
func TestAssemblyAIAudioTranscriptLoader_Load(t *testing.T) {
	t.Parallel()

	apiKey := os.Getenv("ASSEMBLYAI_API_KEY")
	if apiKey == "" {
		t.Skip("ASSEMBLYAI_API_KEY not set")
	}

	const audioURL = "https://github.com/AssemblyAI-Examples/audio-examples/raw/main/20230607_me_canadian_wildfires.mp3"

	params := &aai.TranscriptOptionalParams{
		RedactPII:         aai.Bool(true),
		RedactPIIPolicies: []aai.PIIPolicy{"person_name"},
	}

	loader := NewAssemblyAIAudioTranscript(
		apiKey,
		WithAudioURL(audioURL),
		WithTranscriptFormat(TranscriptFormatText),
		WithTranscriptParams(params),
	)

	docs, err := loader.Load(context.Background())
	require.NoError(t, err)
	require.Len(t, docs, 1)
	require.NotEmpty(t, docs[0].PageContent)

	// The metadata value must be present and be a bool before it is checked.
	redactPII, ok := docs[0].Metadata["redact_pii"].(bool)
	require.True(t, ok)
	require.True(t, redactPII)
}

// TestAssemblyAIAudioTranscriptLoader_toMetadata checks that toMetadata keeps
// regular fields under their JSON names and drops the transcript text.
func TestAssemblyAIAudioTranscriptLoader_toMetadata(t *testing.T) {
	t.Parallel()

	sentence := aai.TranscriptSentence{
		Speaker: aai.String("1"),
		Text:    aai.String("This is a test sentence."),
	}

	metadata, err := toMetadata(sentence)
	require.NoError(t, err)

	require.Equal(t, "1", metadata["speaker"])
	require.Nil(t, metadata["text"])
}
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
cloud.google.com/go/iam v1.1.5 // indirect
cloud.google.com/go/longrunning v0.5.4 // indirect
dario.cat/mergo v1.0.0 // indirect
github.com/AssemblyAI/assemblyai-go-sdk v1.3.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
Expand All @@ -51,6 +52,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sts v1.28.1 // indirect
github.com/aws/smithy-go v1.20.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cockroachdb/errors v1.9.1 // indirect
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
Expand Down Expand Up @@ -83,6 +85,7 @@ require (
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/flatbuffers v23.5.26+incompatible // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
Expand Down Expand Up @@ -167,6 +170,7 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
nhooyr.io/websocket v1.8.7 // indirect
)

require (
Expand Down
Loading

0 comments on commit 0218733

Please sign in to comment.