Skip to content

Commit

Permalink
/connect_wavesを追加 (#109)
Browse files Browse the repository at this point in the history
* /connect_wavesを追加

* エラーハンドリングを修正

* waveの代わりにpysoundfileを使用

* リストが空の時422を返すよう修正

* チャンネルが異なる場合の処理を追加

* テストが通るよう修正

* 説明文を改善

Co-authored-by: Hiroshiba <[email protected]>

* 処理を関数として切り出し

* Apply suggestions from code review

Co-authored-by: Hiroshiba <[email protected]>
  • Loading branch information
takana-v and Hiroshiba authored Sep 24, 2021
1 parent c301668 commit bb681d0
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import argparse
import base64
import io
import sys
import zipfile
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import List, Optional

import numpy as np
import soundfile
import uvicorn
from fastapi import FastAPI, HTTPException, Response
Expand Down Expand Up @@ -182,6 +185,35 @@ def create_accent_phrases(text: str, speaker_id: int) -> List[AccentPhrase]:
speaker_id=speaker_id,
)

def decode_base64_waves(waves: List[str]):
if len(waves) == 0:
raise HTTPException(status_code=422, detail="wavファイルが含まれていません")

waves_nparray = []
for i in range(len(waves)):
try:
wav_bin = base64.standard_b64decode(waves[i])
except ValueError:
raise HTTPException(status_code=422, detail="base64デコードに失敗しました")
try:
_data, _sampling_rate = soundfile.read(io.BytesIO(wav_bin))
except Exception:
raise HTTPException(status_code=422, detail="wavファイルを読み込めませんでした")
if i == 0:
sampling_rate = _sampling_rate
channels = _data.ndim
else:
if sampling_rate != _sampling_rate:
raise HTTPException(status_code=422, detail="ファイル間でサンプリングレートが異なります")
if channels != _data.ndim:
if channels == 1:
_data = _data.T[0]
else:
_data = np.array([_data, _data]).T
waves_nparray.append(_data)

return waves_nparray, sampling_rate

@app.post(
"/audio_query",
response_model=AudioQuery,
Expand Down Expand Up @@ -336,6 +368,35 @@ def multi_synthesis(queries: List[AudioQuery], speaker: int):

return FileResponse(f.name, media_type="application/zip")

@app.post(
"/connect_waves",
response_class=FileResponse,
responses={
200: {
"content": {
"audio/wav": {"schema": {"type": "string", "format": "binary"}}
},
}
},
tags=["その他"],
summary="base64エンコードされた複数のwavデータを一つに結合する",
)
def connect_waves(waves: List[str]):
"""
base64エンコードされたwavデータを一纏めにし、wavファイルで返します。
"""
waves_nparray, sampling_rate = decode_base64_waves(waves)

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
file=f,
data=np.concatenate(waves_nparray),
samplerate=sampling_rate,
format="WAV",
)

return FileResponse(f.name, media_type="audio/wav")

@app.get("/version", tags=["その他"])
def version() -> str:
return (root_dir / "VERSION.txt").read_text()
Expand Down

0 comments on commit bb681d0

Please sign in to comment.