Dev #20

Merged (7 commits) · Nov 27, 2024
173 changes: 173 additions & 0 deletions .ipynb_checkpoints/main-checkpoint.py
@@ -0,0 +1,173 @@
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel, Field
from typing import Dict
from enum import Enum
from datetime import datetime
import os
import json

import anthropic
from dotenv import load_dotenv

from crud import difficulty, pronounce, score

app = FastAPI()
load_dotenv()

class PronounceRule(str, Enum):
구개음화 = "구개음화"
연음화 = "연음화"
경음화 = "경음화"
유음화 = "유음화"
비음화 = "비음화"
음운규칙_없음 = "음운규칙 없음"
겹받침_쓰기 = "겹받침 쓰기"
기식음화 = "기식음화"

class ClaudeRequest(BaseModel):
    difficulty: int = Field(default=3, ge=1, le=5, description="1 (single words) to 5 (full sentences)")
    rule: PronounceRule
    count: int = Field(default=5, ge=1, description="Number of problems to generate")

@app.post("/phonological_rules")
async def analysis_pronounce(text: Dict[int, str] = Body(
    example={
        "1": "맏이가 동생을 돌보았다",
        "2": "굳이 그렇게까지 할 필요는 없어",
        "3": "해돋이를 보러 산에 올랐다",
        "4": "옷이 낡아서 새로 샀다",
        "5": "같이 영화 보러 갈래?"
    }
)):
analysis = {}
for n, t in text.items():
if not t:
raise HTTPException(status_code=400, detail="text에 빈 문자열이 포함되어 있습니다.")
        analysis[n] = pronounce.pronounce_crud(t)
return analysis
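# Example: POST /phonological_rules with {"1": "같이"} returns {"1": <analysis>},
# where each value is whatever pronounce.pronounce_crud produces for that sentence.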

@app.post("/claude")
async def generate_claude(request: ClaudeRequest):
try:

client_claude = anthropic.Anthropic(
            api_key=os.getenv('CLAUDE_API_KEY'),  # can be omitted if the environment variable is set
)
message = client_claude.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=1000,
            # Adjust temperature to get varied outputs
temperature=0.5,
system="너는 음운 규칙별 받아쓰기 문제를 생성하는거야. 음운 규칙에는 구개음화, 연음화, 경음화, 유음화, 비음화, 음운규칙 없음, 겹받침 쓰기, 기식음화가 있어.\n내가 'n 난이도로 [m]유형으로 k문제 만들어줘' 라고 하면 맞춰서 받아쓰기 문제를 만들어줘.\nn: 1~5 (초등학교 기준, 1: 단어, 2: 쉬운 단어가 있는 간단한 문장, 3: 쉬운 단어가 있는 짧은 문장, 4: 짧은 문장, 5: 문장)\nm: 구개음화, 연음화, 경음화, 유음화, 비음화, 음운규칙 없음, 겹받침 쓰기, 기식음화\n답변 형식:\n문제번호:문제 형태로 json형식으로 반환",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": f"{request.difficulty} 난이도로 [{request.rule}] 유형으로 {request.count}문제 만들어줘. (seed: {datetime.now().isoformat()})"
}
]
}
]
)
generated_problem = message.content[0].text
generated_problem = json.loads(generated_problem)
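        # Assumes the model replies with a bare JSON object; any surrounding prose
        # would make json.loads raise, which the except below surfaces as a 500.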
return generated_problem

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

class DifficultyRequest(BaseModel):
    text: str = Field(default="맏이가 동생을 돌보았다")

@app.post("/difficulty")
async def calc_difficulty(text: DifficultyRequest):
b_grade={
'ㄱ':2, 'ㄴ':2, 'ㄹ':2, 'ㅁ':2, 'ㅇ':2,
'ㄷ':3, 'ㅂ':3, 'ㅅ':3, 'ㅈ':3, 'ㅎ':3, 'ㅆ':3,
'ㅊ':4, 'ㅋ':4, 'ㅌ':4, 'ㅍ':4, 'ㄲ':4,
'ㄵ':5, 'ㄶ':5,
'ㄺ':6, 'ㄻ':6, 'ㄼ':6, 'ㅀ':6, 'ㅄ':6,
'ㄳ':7, 'ㄽ':7, 'ㄾ':7, 'ㄿ':7,
}
m_grade={
'ㅏ':1, 'ㅓ':1, 'ㅗ':1, 'ㅜ':1, 'ㅡ':1, 'ㅣ':1,
'ㅐ':2, 'ㅔ':2,
'ㅑ':3, 'ㅕ':3, 'ㅛ':3,
'ㅚ':4, 'ㅟ':4,
'ㅘ':5, 'ㅝ':5, 'ㅢ':5,
'ㅖ':6, 'ㅙ':6, 'ㅞ':6,
'ㅒ':7, 'ㅠ':7,
}
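    # Each vowel and final consonant (batchim) in the pronounced form is graded by
    # the tables above; the summed grade is divided by 5 and clamped to the 1-5 range.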

    # Extract the pronounced forms and score only those jamo
s = text.text
analysis = pronounce.pronounce_crud(s)

    spro = ''
    for k, v in analysis.items():
        if not v:
            continue
        spro += ''.join(v)

    b_list, m_list = difficulty_dec(spro)
    # Default missing jamo to 0 so sum() never sees None
    b_grade_sum = sum(b_grade.get(b, 0) for b in b_list)
    m_grade_sum = sum(m_grade.get(m, 0) for m in m_list)
    total = (b_grade_sum + m_grade_sum) // 5
    # Clamp to the 1-5 difficulty range
    return min(max(total, 1), 5)

class ScoreRequest(BaseModel):
    workbook: Dict[int, str] = Field(description="Problem set (workbook)")
    answer: str = Field(description="S3 URL of the submitted answer sheet")

@app.post("/score")
async def score_endpoint(s: ScoreRequest = Body(
    example={
        "workbook": {
            "1": "맏이가 동생을 돌보았다",
            "2": "굳이 그렇게까지 할 필요는 없어",
            "3": "해돋이를 보러 산에 올랐다",
            "4": "옷이 낡아서 새로 샀다",
            "5": "같이 영화 보러 갈래?",
            "6": "밥먹고 영화 할 사람?"
        },
        "answer": "https://bada-static-bucket.s3.ap-northeast-2.amazonaws.com/1085767.png"
    }
)):
    response = score.score_crud(s)
    # `response` maps problem numbers to scores, e.g. {"1": 80, "2": 90, "3": 47}
    return response


@app.get("/")
async def root():
return {"message": "한글바다 AI 서버입니다."}

def difficulty_dec(s: str):
    res = difficulty.decomposition(s)
    b_list = []
    m_list = []
    # Drop the whitespace placeholder used for an empty jongseong slot
    strip_list = [[col for col in row if col.strip()] for row in res]

    for i in strip_list:
        # Non-Hangul rows have fewer than two jamo; skip them to avoid an IndexError
        if len(i) < 2:
            continue
        m_list.append(i[1])
        if len(i) == 3:
            b_list.append(i[2])
    return b_list, m_list


# Defined above the entry point so it also exists when the module is run directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)

# uvicorn main:app --reload
31 changes: 31 additions & 0 deletions .ipynb_checkpoints/requirements-checkpoint.txt
@@ -0,0 +1,31 @@
annotated-types==0.7.0
anthropic==0.39.0
anyio==4.6.2.post1
certifi==2024.8.30
click==8.1.7
colorama==0.4.6
distro==1.9.0
emoji==1.2.0
exceptiongroup==1.2.2
fastapi==0.115.4
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
idna==3.10
iniconfig==2.0.0
jiter==0.7.0
numpy==2.0.2
packaging==24.1
pecab==1.0.8
pluggy==1.5.0
pyarrow==18.0.0
pydantic==2.9.2
pydantic_core==2.23.4
pytest==8.3.3
python-dotenv==1.0.1
regex==2024.11.6
sniffio==1.3.1
starlette==0.41.2
tomli==2.0.2
typing_extensions==4.12.2
uvicorn==0.32.0
# Note: crud/ocr.py also imports easyocr and opencv-python (cv2); they are not pinned here.
25 changes: 25 additions & 0 deletions crud/.ipynb_checkpoints/difficulty-checkpoint.py
@@ -0,0 +1,25 @@
import re
def decomposition(korean_word: str):
korean_word = re.sub(r'[!"#$%&\'()*+,-./:;<=>?@\[\]^_\`{|}~\\\\]','', korean_word)
    # Choseong (initial consonant) list: indices 0-18
CHOSUNG_LIST = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
    # Jungseong (vowel) list: indices 0-20
JUNGSUNG_LIST = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ', 'ㅣ']
    # Jongseong (final consonant) list: indices 0-27; index 0 is the empty jongseong
JONGSUNG_LIST = [' ', 'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
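
    # Precomposed Hangul syllables are laid out in Unicode as
    #   code = 0xAC00 ('가') + (choseong*21 + jungseong)*28 + jongseong,
    # so each choseong spans 21*28 = 588 code points and each jungseong spans 28.
    # Worked example: '맏' (U+B9CF, offset 3535): 3535//588 = 6 -> 'ㅁ',
    # (3535 - 6*588)//28 = 0 -> 'ㅏ', remainder 7 -> 'ㄷ'.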

res = []
for w in list(korean_word.strip()):
        # Non-Hangul characters (e.g., English letters) are appended as-is below.
if '가'<=w<='힣':
            # Choseong index changes every 588 (= 21*28) code points.
ch1 = (ord(w) - ord('가'))//588
            # Each (choseong, jungseong) pair spans 28 jongseong slots.
ch2 = ((ord(w) - ord('가')) - (588*ch1)) // 28
ch3 = (ord(w) - ord('가')) - (588*ch1) - 28*ch2
res.append([CHOSUNG_LIST[ch1], JUNGSUNG_LIST[ch2], JONGSUNG_LIST[ch3]])
else:
res.append([w])

return res
147 changes: 147 additions & 0 deletions crud/.ipynb_checkpoints/ocr-checkpoint.py
@@ -0,0 +1,147 @@
# Import
import io
import re
import cv2
import easyocr
import numpy as np


def group_text_by_coord(texts, coordinates, y_threshold=40):
    """
    Group texts into lines by their coordinates.

    Parameters:
    - texts: list of recognized strings
    - coordinates: list of top-left coordinates [[x1, y1], [x2, y2], ...]
    - y_threshold: maximum Y difference for two boxes to share a line

    Returns:
    - List of text groups ordered top-to-bottom, each sorted left-to-right,
      e.g. [['str1', 'str1-1'], ['str2', 'str2-1']]
    """
# Sort Coordinates
sorted_coords = sorted(coordinates, key=lambda p: p[1])

# List of Coordinate and Text Pairs
coord_text_pairs = list(zip(coordinates, texts))

groups = []
current_group = []
current_y = sorted_coords[0][1]

# Sort Coordinate and Text Pairs by Y
sorted_pairs = sorted(coord_text_pairs, key=lambda x: x[0][1])

for coord, text in sorted_pairs:
if not current_group or abs(coord[1] - current_y) <= y_threshold:
current_group.append((coord, text))
else:
# Sort Group by X
groups.append(sorted(current_group, key=lambda x: x[0][0]))
current_group = [(coord, text)]
current_y = coord[1]

# Append Last Group
if current_group:
groups.append(sorted(current_group, key=lambda x: x[0][0]))

# Get Texts Each Group
text_groups = [[pair[1] for pair in group] for group in groups]

return text_groups
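# Illustration: texts ["10", "같이", "옷이"] at coordinates [[5, 12], [40, 15], [5, 70]]
# group into [["10", "같이"], ["옷이"]] with the default y_threshold of 40.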


def text_preprocess(infer_text, first_coord, coord, y_thres):
    number_count = 0
    number_list = [str(n) for n in range(1, 21)]  # "1" .. "20"
    output_text = []
    output_coord = []

    # Count boxes that are bare numbers (likely problem-number labels)
    for i, text in enumerate(infer_text):
        if text in number_list:
            number_count += 1

    for i, text in enumerate(infer_text):
        # Case 1. Remove a stand-alone full stop (the dot must be escaped,
        # since a bare '.' matches any character)
        if text == ".":
            text = re.sub(r'\.', '', text)

        # Case 2. Remove leading number patterns - e.g., "1.", "1-"
        text = re.sub(r'^[0-9]{1,2}[.-]', '', text)

        # Case 2-1. If five or more boxes are bare numbers, strip leading numbers as labels
        if number_count >= 5:
            text = re.sub(r'^[0-9]{1,2}', '', text)

        # Case 2-2. If the box is shorter than the Y threshold, strip the leading number
        if abs(float(coord[i][2][1]) - float(coord[i][0][1])) <= y_thres:
            text = re.sub(r'^[0-9]{1,2}', '', text)

        # Case 3. Remove special symbols (hyphen last in the class so it is not
        # parsed as a range; the original ",.-_" spanned digits and uppercase letters)
        text = re.sub(r'^[!?~/@#$%^&*,._+=-]*|[/@#$%^&*_+=]*$', '', text)

        # Case 4. Remove leading/trailing alphabet characters
        text = re.sub(r'^[a-zA-Z]*|[a-zA-Z]*$', '', text)

        # Case 5. Strip leading/trailing whitespace
        text = text.strip()

        # Case 6. Replace hyphens with full stops
        result = text.replace("-", ".")

if result != "":
output_text.append(result)
output_coord.append(first_coord[i])

return output_text, output_coord
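# Example behavior of text_preprocess: "1. 맏이가" -> Case 2 strips the leading "1."
# and Case 5 trims the leftover space, leaving "맏이가".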


def infer_ocr(filepath): # `filepath` is S3 Path
# Initialize EasyOCR Reader
reader = easyocr.Reader(
['ko'],
model_storage_directory='ml/model',
user_network_directory='ml/user_network',
recog_network='custom',
download_enabled=False,
)

    # Run OCR
    result = reader.readtext(filepath, width_ths=0.2)

    # Define the confidence threshold
    conf_thres = 0.1

coord = []
first_coord = []
infer_text = []
infer_conf = []
y_thres_list = []
y_thres = 50
for i, rst in enumerate(result):
        if rst[2] >= conf_thres:  # keep results at or above the confidence threshold
tmp = []
for j in rst[0]:
tmp.append([j[0], j[1]])
first_coord.append(tmp[0])
coord.append(tmp)
infer_text.append(rst[1])
infer_conf.append(rst[2])

    # Calculate the Y threshold as the mean detected box height
    for i in coord:
        height = abs(float(i[0][1]) - float(i[2][1]))
        y_thres_list.append(height)
    if y_thres_list:  # keep the default of 50 when nothing was detected
        y_thres = np.mean(y_thres_list)

# Text Preprocessing
infer_text, first_coord = text_preprocess(infer_text, first_coord, coord, y_thres)

# Group Text by Coord
grouped_texts = group_text_by_coord(infer_text, first_coord, y_thres)

infer_proc_text = {}
for i, group in enumerate(grouped_texts):
tmp = " ".join(group)
infer_proc_text[str(i + 1)] = tmp

return {"results": infer_proc_text}