Merge pull request #20 from hangeulbada/dev
Showing 268 changed files with 19,599,391 additions and 83 deletions.
@@ -0,0 +1,173 @@
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel, Field
from typing import Dict
from enum import Enum
import os
import json
from dotenv import load_dotenv
from crud import difficulty, pronounce, score

app = FastAPI()
load_dotenv()


class PronounceRule(str, Enum):
    구개음화 = "구개음화"
    연음화 = "연음화"
    경음화 = "경음화"
    유음화 = "유음화"
    비음화 = "비음화"
    음운규칙_없음 = "음운규칙 없음"
    겹받침_쓰기 = "겹받침 쓰기"
    기식음화 = "기식음화"


class ClaudeRequest(BaseModel):
    difficulty: int = Field(default=3)
    rule: PronounceRule
    count: int = Field(default=5)


@app.post("/phonological_rules")
async def analysis_pronounce(text: Dict[int, str] = Body(
    example={
        "1": "맏이가 동생을 돌보았다",
        "2": "굳이 그렇게까지 할 필요는 없어",
        "3": "해돋이를 보러 산에 올랐다",
        "4": "옷이 낡아서 새로 샀다",
        "5": "같이 영화 보러 갈래?"
    }
)):
    analysis = {}
    for n, t in text.items():
        if not t:
            raise HTTPException(status_code=400, detail="text에 빈 문자열이 포함되어 있습니다.")
        analysis[n] = pronounce.pronounce_crud(t)
    return analysis


@app.post("/claude")
async def generate_claude(request: ClaudeRequest):
    try:
        import anthropic
        from datetime import datetime

        client_claude = anthropic.Anthropic(
            api_key=os.getenv('CLAUDE_API_KEY'),  # can be omitted if the environment variable is set
        )
        message = client_claude.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=1000,
            # temperature raised to diversify the generated problems
            temperature=0.5,
            system="너는 음운 규칙별 받아쓰기 문제를 생성하는거야. 음운 규칙에는 구개음화, 연음화, 경음화, 유음화, 비음화, 음운규칙 없음, 겹받침 쓰기, 기식음화가 있어.\n내가 'n 난이도로 [m]유형으로 k문제 만들어줘' 라고 하면 맞춰서 받아쓰기 문제를 만들어줘.\nn: 1~5 (초등학교 기준, 1: 단어, 2: 쉬운 단어가 있는 간단한 문장, 3: 쉬운 단어가 있는 짧은 문장, 4: 짧은 문장, 5: 문장)\nm: 구개음화, 연음화, 경음화, 유음화, 비음화, 음운규칙 없음, 겹받침 쓰기, 기식음화\n답변 형식:\n문제번호:문제 형태로 json형식으로 반환",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{request.difficulty} 난이도로 [{request.rule}] 유형으로 {request.count}문제 만들어줘. (seed: {datetime.now().isoformat()})"
                        }
                    ]
                }
            ]
        )
        # Claude returns a JSON string; parse it into a dict before returning
        generated_problem = message.content[0].text
        generated_problem = json.loads(generated_problem)
        return generated_problem

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class DifficultyRequest(BaseModel):
    text: str = Field("맏이가 동생을 돌보았다")


def difficulty_dec(s: str):
    """Decompose a string into jamo and collect its medial vowels and final consonants."""
    res = difficulty.decomposition(s)
    b_list = []  # final consonants (batchim)
    m_list = []  # medial vowels
    # Drop the blank placeholder used for syllables without a final consonant
    strip_list = [[col for col in row if col.strip()] for row in res]

    for i in strip_list:
        # Skip empty rows and non-Hangul characters, which decompose to a single element
        if len(i) < 2:
            continue
        m_list.append(i[1])
        if len(i) == 3:
            b_list.append(i[2])
    return b_list, m_list


@app.post("/difficulty")
async def calc_difficulty(text: DifficultyRequest):
    # Difficulty grade per final consonant (batchim)
    b_grade = {
        'ㄱ': 2, 'ㄴ': 2, 'ㄹ': 2, 'ㅁ': 2, 'ㅇ': 2,
        'ㄷ': 3, 'ㅂ': 3, 'ㅅ': 3, 'ㅈ': 3, 'ㅎ': 3, 'ㅆ': 3,
        'ㅊ': 4, 'ㅋ': 4, 'ㅌ': 4, 'ㅍ': 4, 'ㄲ': 4,
        'ㄵ': 5, 'ㄶ': 5,
        'ㄺ': 6, 'ㄻ': 6, 'ㄼ': 6, 'ㅀ': 6, 'ㅄ': 6,
        'ㄳ': 7, 'ㄽ': 7, 'ㄾ': 7, 'ㄿ': 7,
    }
    # Difficulty grade per medial vowel
    m_grade = {
        'ㅏ': 1, 'ㅓ': 1, 'ㅗ': 1, 'ㅜ': 1, 'ㅡ': 1, 'ㅣ': 1,
        'ㅐ': 2, 'ㅔ': 2,
        'ㅑ': 3, 'ㅕ': 3, 'ㅛ': 3,
        'ㅚ': 4, 'ㅟ': 4,
        'ㅘ': 5, 'ㅝ': 5, 'ㅢ': 5,
        'ㅖ': 6, 'ㅙ': 6, 'ㅞ': 6,
        'ㅒ': 7, 'ㅠ': 7,
    }

    # Extract the segments affected by pronunciation rules and score only those parts
    s = text.text
    analysis = pronounce.pronounce_crud(s)

    spro = ''
    for k, v in analysis.items():
        if not v:
            continue
        spro += ''.join(v)

    b_list, m_list = difficulty_dec(spro)
    b_grade_sum = sum(b_grade.get(b, 0) for b in b_list)
    m_grade_sum = sum(m_grade.get(m, 0) for m in m_list)
    # Average the jamo grades and clamp to the 1-5 difficulty range
    total = (b_grade_sum + m_grade_sum) // 5
    total = max(1, min(total, 5))
    return total


class ScoreRequest(BaseModel):
    workbook: Dict[int, str] = Field(description="문제집")
    answer: str = Field(description="답안 S3 주소")


@app.post("/score")
async def score_endpoint(s: ScoreRequest = Body(
    example={
        "workbook": {
            "1": "맏이가 동생을 돌보았다",
            "2": "굳이 그렇게까지 할 필요는 없어",
            "3": "해돋이를 보러 산에 올랐다",
            "4": "옷이 낡아서 새로 샀다",
            "5": "같이 영화 보러 갈래?",
            "6": "밥먹고 영화 할 사람?"
        },
        "answer": "https://bada-static-bucket.s3.ap-northeast-2.amazonaws.com/1085767.png"
    }
)):
    # Returns per-problem scores, e.g. {"1": 80, "2": 90, "3": 47}
    response = score.score_crud(s)
    return response


@app.get("/")
async def root():
    return {"message": "한글바다 AI 서버입니다."}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)

# uvicorn main:app --reload
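
For reference, the endpoints above can be exercised with httpx, which is already pinned in the requirements below. This is a minimal client-side sketch, assuming the server was started locally with `uvicorn main:app --reload` on port 8000 and, for /claude, that CLAUDE_API_KEY is set in .env; the payloads mirror the example bodies above.

import httpx

BASE = "http://127.0.0.1:8000"  # assumes a local dev server

# Per-sentence phonological-rule analysis
sentences = {"1": "맏이가 동생을 돌보았다", "2": "같이 영화 보러 갈래?"}
print(httpx.post(f"{BASE}/phonological_rules", json=sentences).json())

# Difficulty estimate: returns an integer from 1 to 5
print(httpx.post(f"{BASE}/difficulty", json={"text": "맏이가 동생을 돌보았다"}).json())

# Dictation-problem generation via Claude (requires CLAUDE_API_KEY in .env)
payload = {"difficulty": 3, "rule": "구개음화", "count": 5}
print(httpx.post(f"{BASE}/claude", json=payload).json())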
@@ -0,0 +1,31 @@
annotated-types==0.7.0
anthropic==0.39.0
anyio==4.6.2.post1
certifi==2024.8.30
click==8.1.7
colorama==0.4.6
distro==1.9.0
emoji==1.2.0
exceptiongroup==1.2.2
fastapi==0.115.4
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
idna==3.10
iniconfig==2.0.0
jiter==0.7.0
numpy==2.0.2
packaging==24.1
pecab==1.0.8
pluggy==1.5.0
pyarrow==18.0.0
pydantic==2.9.2
pydantic_core==2.23.4
pytest==8.3.3
python-dotenv==1.0.1
regex==2024.11.6
sniffio==1.3.1
starlette==0.41.2
tomli==2.0.2
typing_extensions==4.12.2
uvicorn==0.32.0
@@ -0,0 +1,25 @@
import re


def decomposition(korean_word: str):
    # Strip punctuation and other special symbols before decomposing
    korean_word = re.sub(r'[!"#$%&\'()*+,-./:;<=>?@\[\]^_`{|}~\\]', '', korean_word)
    # Choseong (initial consonant) list, indices 00-18
    CHOSUNG_LIST = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
    # Jungseong (medial vowel) list, indices 00-20
    JUNGSUNG_LIST = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ', 'ㅣ']
    # Jongseong (final consonant) list, indices 00-27; index 0 (' ') means no final consonant
    JONGSUNG_LIST = [' ', 'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']

    res = []
    for w in list(korean_word.strip()):
        # Non-Hangul characters (e.g., English letters) are appended as-is
        if '가' <= w <= '힣':
            # The choseong changes every 588 (= 21 jungseong * 28 jongseong) code points
            ch1 = (ord(w) - ord('가')) // 588
            # Within a choseong block, the jungseong changes every 28 code points (one slot per jongseong)
            ch2 = ((ord(w) - ord('가')) - (588 * ch1)) // 28
            # The remainder is the jongseong index
            ch3 = (ord(w) - ord('가')) - (588 * ch1) - 28 * ch2
            res.append([CHOSUNG_LIST[ch1], JUNGSUNG_LIST[ch2], JONGSUNG_LIST[ch3]])
        else:
            res.append([w])

    return res
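
A worked check of the offset arithmetic, based on the layout of the Unicode Hangul Syllables block (syllables are ordered choseong-major, then jungseong, then jongseong, so each choseong spans 588 = 21 * 28 code points):

# '맏' = ㅁ + ㅏ + ㄷ
offset = ord('맏') - ord('가')        # 3535
ch1 = offset // 588                   # 6 -> CHOSUNG_LIST[6]  == 'ㅁ'
ch2 = (offset - 588 * ch1) // 28      # 0 -> JUNGSUNG_LIST[0] == 'ㅏ'
ch3 = offset - 588 * ch1 - 28 * ch2   # 7 -> JONGSUNG_LIST[7] == 'ㄷ'

# decomposition("맏이") therefore returns
# [['ㅁ', 'ㅏ', 'ㄷ'], ['ㅇ', 'ㅣ', ' ']]   (' ' marks the absent jongseong)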
@@ -0,0 +1,147 @@
# Import
import re

import easyocr
import numpy as np


def group_text_by_coord(texts, coordinates, y_threshold=40):
    """
    Group texts into lines by their top-left coordinates.

    Parameters:
    - texts: list of recognized text fragments
    - coordinates: list of top-left coordinates [[x1, y1], [x2, y2], ...]
    - y_threshold: maximum y-difference for fragments to count as the same line

    Returns:
    - [['str1', 'str1-1'], ['str2', 'str2-1']]
    """
    # Sort coordinates by y to find the topmost line
    sorted_coords = sorted(coordinates, key=lambda p: p[1])

    # Pair each coordinate with its text
    coord_text_pairs = list(zip(coordinates, texts))

    groups = []
    current_group = []
    current_y = sorted_coords[0][1]

    # Sort coordinate/text pairs by y
    sorted_pairs = sorted(coord_text_pairs, key=lambda x: x[0][1])

    for coord, text in sorted_pairs:
        if not current_group or abs(coord[1] - current_y) <= y_threshold:
            current_group.append((coord, text))
        else:
            # Close the current line, ordered left to right by x
            groups.append(sorted(current_group, key=lambda x: x[0][0]))
            current_group = [(coord, text)]
            current_y = coord[1]

    # Append the last group
    if current_group:
        groups.append(sorted(current_group, key=lambda x: x[0][0]))

    # Keep only the texts of each group
    text_groups = [[pair[1] for pair in group] for group in groups]

    return text_groups


def text_preprocess(infer_text, first_coord, coord, y_thres):
    number_count = 0
    number_list = [str(n) for n in range(1, 21)]  # "1" .. "20"
    output_text = []
    output_coord = []

    # Count fragments that are bare problem numbers
    for text in infer_text:
        if text in number_list:
            number_count += 1

    for i, text in enumerate(infer_text):
        # Case 1. Drop fragments that are just a full stop
        if text == ".":
            text = ""

        # Case 2. Remove leading number patterns, e.g. "1." or "1-"
        text = re.sub(r'^[0-9]{1,2}[.-]', '', text)

        # Case 2-1. If five or more bare numbers were seen, strip leading numbers
        if number_count >= 5:
            text = re.sub(r'^[0-9]{1,2}', '', text)

        # Case 2-2. If the box of a number is too small, strip the leading number
        if abs(float(coord[i][2][1]) - float(coord[i][0][1])) <= y_thres:
            text = re.sub(r'^[0-9]{1,2}', '', text)

        # Case 3. Remove leading/trailing special symbols
        # (hyphen escaped so the character class is not parsed as a range)
        text = re.sub(r'^[!?~/@#$%^&*,.\-_+=]*|[/@#$%^&*_+=]*$', '', text)

        # Case 4. Remove leading/trailing alphabet characters
        text = re.sub(r'^[a-zA-Z]*|[a-zA-Z]*$', '', text)

        # Case 5. Remove leading/trailing whitespace
        text = re.sub(r'^\s*|\s*$', '', text)

        # Case 6. Replace hyphens with full stops
        result = text.replace("-", ".")

        if result != "":
            output_text.append(result)
            output_coord.append(first_coord[i])

    return output_text, output_coord


def infer_ocr(filepath):  # `filepath` is an S3 path
    # Initialize the EasyOCR reader with the custom recognition network
    reader = easyocr.Reader(
        ['ko'],
        model_storage_directory='ml/model',
        user_network_directory='ml/user_network',
        recog_network='custom',
        download_enabled=False,
    )

    # Run OCR
    result = reader.readtext(filepath, width_ths=0.2)

    # Confidence threshold
    conf_thres = 0.1

    coord = []
    first_coord = []
    infer_text = []
    infer_conf = []
    y_thres_list = []
    y_thres = 50
    for rst in result:
        if rst[2] >= conf_thres:  # keep only detections above the confidence threshold
            tmp = []
            for j in rst[0]:
                tmp.append([j[0], j[1]])
            first_coord.append(tmp[0])
            coord.append(tmp)
            infer_text.append(rst[1])
            infer_conf.append(rst[2])

    # Use the mean box height as the y threshold (keep the default if nothing was detected)
    for box in coord:
        height = abs(float(box[0][1]) - float(box[2][1]))
        y_thres_list.append(height)
    if y_thres_list:
        y_thres = np.mean(y_thres_list)

    # Text preprocessing
    infer_text, first_coord = text_preprocess(infer_text, first_coord, coord, y_thres)

    # Group texts into lines by coordinate
    grouped_texts = group_text_by_coord(infer_text, first_coord, y_thres)

    infer_proc_text = {}
    for i, group in enumerate(grouped_texts):
        infer_proc_text[str(i + 1)] = " ".join(group)

    return {"results": infer_proc_text}
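
To illustrate the grouping step in isolation, here is a small sketch with hypothetical top-left coordinates (real values come from EasyOCR's bounding boxes): fragments whose y-coordinates fall within y_threshold of a line's anchor are merged into one line and ordered left to right.

# Two answer lines, each recognized as a number box plus a text box
texts = ["1.", "맏이가 동생을 돌보았다", "2.", "굳이 그렇게까지"]
coords = [[10, 100], [60, 103], [10, 200], [60, 198]]

print(group_text_by_coord(texts, coords, y_threshold=40))
# -> [['1.', '맏이가 동생을 돌보았다'], ['2.', '굳이 그렇게까지']]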