๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ

์Œ์„ฑ์ฒ˜๋ฆฌ

[๋”ฅ๋Ÿฌ๋‹]CNN, RNN, LSTM ์ด๋ก  ๋ฐ ๊ตฌํ˜„


1. CNN(Convolutional Neural Network)

 

 

๐Ÿ“Œ CNN

CNN์€ ์˜์ƒ ์ฒ˜๋ฆฌ ๋ถ„์•ผ์—์„œ ์‹œ์ž‘๋œ ์‹ ๊ฒฝ๋ง ๊ตฌ์กฐ์ด์ง€๋งŒ, ์ตœ๊ทผ์—๋Š” ์Œ์„ฑ ๋ฐ์ดํ„ฐ๋ฅผ 2์ฐจ์› ํ˜•ํƒœ(์ŠคํŽ™ํŠธ๋กœ๊ทธ๋žจ, MFCC ๋“ฑ)๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ์ž…๋ ฅํ•จ์œผ๋กœ์จ ์Œ์„ฑ์ธ์‹์—๋„ ๋„๋ฆฌ ํ™œ์šฉ๋˜๊ณ  ์žˆ๋‹ค.

์ฃผ๋กœ ์Œํ–ฅ ๋ชจ๋ธ์˜ ์ „์ฒ˜๋ฆฌ ๊ณ„์ธต ๋˜๋Š” ์ „์ฒด ๋ชจ๋ธ๋กœ ์‚ฌ์šฉ๋˜๋ฉฐ, ๊ณต๊ฐ„์  ํŠน์ง•(์ฃผํŒŒ์ˆ˜ ๊ฐ„ ๊ด€๊ณ„, ์‹œ๊ฐ„์  ๋ณ€ํ™”)์„ ํšจ๊ณผ์ ์œผ๋กœ ์ถ”์ถœํ•œ๋‹ค.

 

โœ”๏ธ Convolution์˜ ๊ฐœ๋…

convolution์€ ์‹ ํ˜ธ ์ฒ˜๋ฆฌ์—์„œ ํŠน์ • ํŒจํ„ด์ด๋‚˜ ์ฃผํŒŒ์ˆ˜ ๋Œ€์—ญ์„ ์ถ”์ถœํ•˜๊ฑฐ๋‚˜, ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉํ•˜๋Š” ํ•ต์‹ฌ ์—ฐ์‚ฐ์ด๋‹ค. ์ž…๋ ฅ ์‹ ํ˜ธ์— ๋Œ€ํ•ด, ํ•„ํ„ฐ ๋˜๋Š” ์ปค๋„์„ ์ ์šฉํ•˜์—ฌ ์ƒˆ๋กœ์šด ์ถœ๋ ฅ์„ ์ƒ์„ฑํ•œ๋‹ค.

 

โ€ป ํ™”์ดํŠธ ๋…ธ์ด์ฆˆ ์‹ ํ˜ธ์— ๋Œ€ํ•ด 3๊ฐœ์˜ ๋Œ€์—ญํ†ต๊ณผ FIR ํ•„ํ„ฐ๋ฅผ ์ ์šฉํ•œ ํ›„, ์‹œ๊ฐ„ ์˜์—ญ๊ณผ ์ฃผํŒŒ์ˆ˜ ์˜์—ญ์—์„œ ํ•„ํ„ฐ๋ง ๊ฒฐ๊ณผ๋ฅผ ์‹œ๊ฐํ™”

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import firwin, lfilter
import librosa
import librosa.display

# 1. ์ž…๋ ฅ ์‹ ํ˜ธ ์ƒ์„ฑ (ํ™”์ดํŠธ ๋…ธ์ด์ฆˆ) : 1์ดˆ ๋™์•ˆ 16,000๊ฐœ์˜ ๋žœ๋ค ์ƒ˜ํ”Œ 
# ํ™”์ดํŠธ ๋…ธ์ด์ฆˆ๋Š” ๋ชจ๋“  ์ฃผํŒŒ์ˆ˜์— ๋™์ผํ•œ ์„ธ๊ธฐ๋ฅผ ๊ฐ€์ง€๋Š” ์‹ ํ˜ธ 
sr = 16000  # ์ƒ˜ํ”Œ๋ง ์ฃผํŒŒ์ˆ˜
T = 1.0     # 1์ดˆ
x = np.random.randn(int(sr * T))

# 2. FIR ๋Œ€์—ญํ†ต๊ณผ ํ•„ํ„ฐ ์„ค๊ณ„ ํ•จ์ˆ˜
# numtaps=101์€ ํ•„ํ„ฐ ๊ณ„์ˆ˜ ์ˆ˜๋กœ ํด์ˆ˜๋ก ์„ ๋ช…ํ•˜๊ณ  ์•ˆ์ •์ ์ด๋‹ค 
def design_bandpass_filter(center_hz, sr, width=100, numtaps=101):
    nyq = sr / 2
    low = max((center_hz - width) / nyq, 1e-4)
    high = min((center_hz + width) / nyq, 0.9999)
    if low >= high:
        raise ValueError(f"์ž˜๋ชป๋œ ์ฃผํŒŒ์ˆ˜ ๋ฒ”์œ„: low={low*nyq}, high={high*nyq}")
    return firwin(numtaps, [low, high], pass_zero=False)

# 3. ํ•„ํ„ฐ ๊ณ„์ˆ˜ ์ƒ์„ฑ(3๊ฐœ์˜ ์ค‘์‹ฌ ์ฃผํŒŒ์ˆ˜ ํ•„ํ„ฐ: 100Hz, 1kHz, 5kHz)
filters = {
    "100Hz": design_bandpass_filter(100, sr),
    "1000Hz": design_bandpass_filter(1000, sr),
    "5000Hz": design_bandpass_filter(5000, sr)
}

# 4. ํ•„ํ„ฐ ์ ์šฉ(๊ฐ ํ•„ํ„ฐ๋ฅผ x์— ์ ์šฉํ•ด์„œ ์ถœ๋ ฅ ์‹ ํ˜ธ ์ƒ์„ฑ)
# ์ถœ๋ ฅ: ์ค‘์‹ฌ ์ฃผํŒŒ์ˆ˜ ๋Œ€์—ญ๋งŒ ํ†ต๊ณผ์‹œํ‚จ ์‹ ํ˜ธ 
outputs = {k: lfilter(h, [1], x) for k, h in filters.items()}

# 5. ์‹œ๊ฐํ™” (์‹œ๊ฐ„ + ์ฃผํŒŒ์ˆ˜)
plt.figure(figsize=(14, 10))

for i, (label, y) in enumerate(outputs.items()):
    # ์‹œ๊ฐ„ ์˜์—ญ
    plt.subplot(len(filters), 2, 2*i + 1)
    librosa.display.waveshow(y, sr=sr)
    plt.title(f"{label} - Time Domain")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")

    # ์ฃผํŒŒ์ˆ˜ ์˜์—ญ (FFT)
    Y = np.abs(np.fft.rfft(y))
    freqs = np.fft.rfftfreq(len(y), d=1/sr)
    plt.subplot(len(filters), 2, 2*i + 2)
    plt.plot(freqs, Y)
    plt.title(f"{label} - Frequency Spectrum")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Magnitude")
    plt.xlim(0, 8000)  # Nyquist ์ œํ•œ

plt.tight_layout()
plt.show()

 

  • 100 Hz filter result:
    • The waveform oscillates very slowly
    • The spectrum is concentrated in the low band only
  • 1000 Hz filter:
    • The waveform oscillates at a moderate rate
    • The spectrum peaks near 1 kHz
  • 5000 Hz filter:
    • The waveform oscillates very rapidly (fine-grained oscillation)
    • Only the 5 kHz band is emphasized in the spectrum

 

 

 


 

 

 

2. RNN(Recurrent Neural Network)

 

 

๐Ÿ“Œ RNN

RNN์€ ์ˆœ์ฐจ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ์— ํŠนํ™”๋œ ์ธ๊ณต์‹ ๊ฒฝ๋ง ๊ตฌ์กฐ์ด๋‹ค. ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ, ์Œ์„ฑ ์ธ์‹, ์‹œ๊ณ„์—ด ์˜ˆ์ธก ๋“ฑ์— ์ฃผ๋กœ ์‚ฌ์šฉ๋œ๋‹ค.

๊ธฐ์กด์˜ ์‹ ๊ฒฝ๋ง์€ ์ž…๋ ฅ ๊ฐ„์˜ ์ˆœ์„œ๋ฅผ ๊ณ ๋ คํ•˜์ง€ ์•Š์ง€๋งŒ, RNN์€ ์ด์ „ ์‹œ์ ์˜ ์ถœ๋ ฅ์„ ๋‹ค์Œ ์‹œ์ ์˜ ์ž…๋ ฅ์— ์žฌ๊ท€์ ์œผ๋กœ ์ „๋‹ฌํ•จ์œผ๋กœ์จ ์‹œ๊ฐ„์ ์ธ ๋งฅ๋ฝ(๊ธฐ์–ต)์„ ๋ฐ˜์˜ํ•  ์ˆ˜ ์žˆ๋‹ค.

 

  • ์ž…๋ ฅ ๋ฒกํ„ฐ๊ฐ€ ์€๋‹‰์ธต์— ๋“ค์–ด๊ฐ
  • ์€๋‹‰์ธต์œผ๋กœ๋ถ€ํ„ฐ ์ถœ๋ ฅ ๋ฒกํ„ฐ๊ฐ€ ์ƒ์„ฑ๋จ
  • ์€๋‹‰์ธต์—์„œ ๋‚˜์™€ ๋‹ค์‹œ ์€๋‹‰์ธต์œผ๋กœ ์ž…๋ ฅ๋จ

 

๐Ÿ“Œ GSC ๋ฐ์ดํ„ฐ์…‹์˜ ์ผ๋ถ€ ์ˆซ์ž ์Œ์„ฑ์„ ๋ฌด์ž‘์œ„๋กœ ์„ ํƒํ•ด 3~5๊ฐœ๋ฅผ ์ด์–ด ๋ถ™์—ฌ ํ•˜๋‚˜์˜ ์—ฐ์† ์Œ์„ฑ์„ ๋งŒ๋“ค๊ณ , ํ•ด๋‹น ์Œ์„ฑ์„ MFCC๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ RNN ๋ชจ๋ธ์„ ํ•™์Šตํ•œ ๋’ค, ๋ฌธ์ž ๋‹จ์œ„ ์˜ˆ์ธก์„ ์ˆ˜ํ–‰ํ•˜๋Š” ์Œ์„ฑ ์ธ์‹ ๋ชจ๋ธ ๊ตฌํ˜„

 

โ€ป ๋ฐ์ดํ„ฐ ๋‹ค์šด๋กœ๋“œ

!mkdir -p gsc
%cd gsc
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
!tar -xf speech_commands_v0.02.tar.gz
%cd ..

 

โ€ป ์ด์–ด๋ถ™์ธ ์Œ์„ฑ ์ƒ˜ํ”Œ ์ƒ์„ฑ

import os
import random
import librosa
import numpy as np
import soundfile as sf

digit_words = ['zero','one','two','three','four','five','six','seven','eight','nine']
src_root = "gsc"

# ์ด์–ด๋ถ™์ธ ์—ฐ์† ์Œ์„ฑ ์ƒ˜ํ”Œ ์ƒ์„ฑ
concat_dir = "gsc_concat"
os.makedirs(concat_dir, exist_ok=True)

samples = []
for i in range(200):
    digits = random.choices(digit_words, k=random.randint(3, 5))   # ex: ['one','six','nine'] 
    wavs = []
    for word in digits:
        folder = os.path.join(src_root, word)
        files = [f for f in os.listdir(folder) if f.endswith(".wav")]
        path = os.path.join(folder, random.choice(files))
        y, _ = librosa.load(path, sr=16000)
        wavs.append(y)
    full = np.concatenate(wavs)    # ์—ฌ๋Ÿฌ ๋‹จ์–ด์˜ ์Œ์„ฑ ๋ฐ์ดํ„ฐ๋ฅผ ์—ฐ๊ฒฐ 
    out_path = os.path.join(concat_dir, f"sample_{i}.wav")
    sf.write(out_path, full, 16000)
    samples.append((out_path, " ".join(digits)))

print("์ƒ˜ํ”Œ ์ƒ์„ฑ ์™„๋ฃŒ:", len(samples))

 

โ€ป ๋ฌธ์ž ์ธ๋ฑ์Šค ๋งคํ•‘ ๋ฐ ์ „์ฒ˜๋ฆฌ 

char_vocab = sorted(list("abcdefghijklmnopqrstuvwxyz '"))
char2idx = {ch: i for i, ch in enumerate(char_vocab)}
PAD_IDX = len(char2idx)   # index used for padding
idx2char = {i: ch for ch, i in char2idx.items()}

# Convert a character-sequence label to indices, then pad to length 100
def text_to_seq(text, max_len=100):
    seq = [char2idx[c] for c in text if c in char2idx]
    seq += [PAD_IDX] * (max_len - len(seq))
    return seq[:max_len]

# MFCC feature extraction
def extract_mfcc(path, max_len=100):
    y, sr = librosa.load(path, sr=16000)
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)    # [13, T]
    if mfcc.shape[1] < max_len:
        pad = np.zeros((13, max_len - mfcc.shape[1]))
        mfcc = np.concatenate((mfcc, pad), axis=1)
    return mfcc[:, :max_len].T  # [T, 13]

X, y = [], []
for path, label in samples:
    X.append(extract_mfcc(path))
    y.append(text_to_seq(label))

# Assemble the training data
X = np.array(X)     # [N, T=100, 13]
y = np.array(y)     # [N, 100]
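A quick round-trip sanity check on the label encoding (restating the vocabulary mapping from above so the snippet stands alone): encoding a transcript and then decoding it while dropping the padding index should recover the original string.

```python
# Same mapping as above: 26 letters + space + apostrophe = 28 characters
char_vocab = sorted(list("abcdefghijklmnopqrstuvwxyz '"))
char2idx = {ch: i for i, ch in enumerate(char_vocab)}
PAD_IDX = len(char2idx)        # 28: one slot past the character vocabulary
idx2char = {i: ch for ch, i in char2idx.items()}

def text_to_seq(text, max_len=100):
    seq = [char2idx[c] for c in text if c in char2idx]
    seq += [PAD_IDX] * (max_len - len(seq))
    return seq[:max_len]

seq = text_to_seq("one six nine")                              # 100 indices, padded
decoded = ''.join(idx2char[i] for i in seq if i != PAD_IDX)    # strip padding
print(decoded)  # one six nine
```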

 

โ€ป ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ๋กœ ๋ณ€ํ™˜

from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)

 

โ€ป RNN ๋ชจ๋ธ ์ •์˜ 

import torch.nn as nn

class RNNModel(nn.Module):
    def __init__(self, input_dim=13, hidden_dim=128, output_dim=len(char2idx)+1):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)     # [B, T, 13] -> [B, T, H]
        self.fc = nn.Linear(hidden_dim, output_dim)    # predict a character distribution at every time step

    def forward(self, x):
        out, _ = self.rnn(x)       # [B, T, H]
        out = self.fc(out)         # [B, T, V]
        return out

    # Thin wrapper: backpropagation is simply loss.backward()
    def backward(self, loss):
        loss.backward()
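To see how the tensor shapes flow through the model, a dummy batch can be pushed through the same two layers in isolation (a sketch assuming output_dim = 29, i.e. the 28-character vocabulary plus the padding index):

```python
import torch
import torch.nn as nn

rnn = nn.RNN(13, 128, batch_first=True)   # same layers as RNNModel above
fc = nn.Linear(128, 29)

x = torch.randn(2, 100, 13)   # [batch, time steps, MFCC coefficients]
out, _ = rnn(x)               # [2, 100, 128]: one hidden vector per frame
logits = fc(out)              # [2, 100, 29]: character scores per frame
print(logits.shape)  # torch.Size([2, 100, 29])
```

Because the linear layer is applied at every time step, the model emits one character prediction per MFCC frame, which is what the per-frame argmax decoding below relies on.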

 

โ€ป ๋ชจ๋ธ ํ•™์Šต ( epoch = 1000)

import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)    # loss that ignores padded positions

loss_history = []  # record of the per-epoch loss

for epoch in range(1000):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb)                    # [B, T, V]
        pred = pred.view(-1, pred.shape[-1])
        yb = yb.view(-1)                    # score the whole sequence at once
        loss = criterion(pred, yb)
        optimizer.zero_grad()               # clear gradients from the previous step
        loss.backward()                     # backpropagation
        optimizer.step()                    # parameter update
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}")

# ๋ชจ๋ธ ์ €์žฅ
torch.save(model.state_dict(), "rnn_speech_model.pth")
print("๋ชจ๋ธ ์ €์žฅ ์™„๋ฃŒ: rnn_speech_model.pth")

# ์†์‹ค ํ•จ์ˆ˜ ๊ทธ๋ž˜ํ”„ ์ถœ๋ ฅ
plt.figure(figsize=(10, 4))
plt.plot(loss_history, label="Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss over Epochs")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

โ€ป ์ถ”๋ก  ๋ฐ ์˜ค๋””์˜ค ์žฌ์ƒํ•จ์ˆ˜ ์ •์˜

from IPython.display import Audio, display

def load_model_for_inference(model_path="rnn_speech_model.pth"):
    model = RNNModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

def infer_and_play_audio(sample_path, label_text=None, model_path="rnn_speech_model.pth"):
    # Play the audio
    print(f"[Playing audio]: {sample_path}")
    display(Audio(sample_path))

    # Load the model
    model = load_model_for_inference(model_path)

    # Preprocess the input
    mfcc = extract_mfcc(sample_path)
    x = torch.tensor(mfcc, dtype=torch.float32).unsqueeze(0).to(device)  # [1, T, F]

    # Inference
    with torch.no_grad():
        output = model(x)  # [1, T, V]
        pred_idx = output.argmax(2)[0]     # index of the most probable character at each frame
        pred_text = ''.join([idx2char[i.item()] for i in pred_idx if i.item() != PAD_IDX])

    # Print the result
    if label_text:
        print("[Ground truth]:", label_text)
    print("[Model prediction]:", pred_text)

 

 

 

โ€ป epoch = 10000์œผ๋กœ ํ•™์Šต ์‹œ 

 

 

 

 

 


 

 

 

3. LSTM (Long Short-Term Memory)

 

๐Ÿ“Œ LSTM

์ˆœ์ฐจ์ ์ธ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•  ๋•Œ, ๊ณผ๊ฑฐ ์ •๋ณด๋ฅผ ์žฅ๊ธฐ์ ์œผ๋กœ ๋ณด์กดํ•˜๋ฉด์„œ ํ•„์š”ํ•œ ์ •๋ณด๋งŒ ํ•™์Šต์— ๋ฐ˜์˜ํ•  ์ˆ˜ ์žˆ๋„๋ก ๊ณ ์•ˆ๋œ RNN ๊ณ„์—ด์˜ ์‹ ๊ฒฝ๋ง์ด๋‹ค. ์ผ๋ฐ˜์ ์ธ RNN์ด ์‹œํ€€์Šค ๊ธธ์ด๊ฐ€ ๊ธธ์–ด์งˆ์ˆ˜๋ก ๊ธฐ์šธ๊ธฐ ์†Œ์‹ค ๋ฌธ์ œ๋กœ ์ธํ•ด ๊ณผ๊ฑฐ ์ •๋ณด๋ฅผ ์žŠ๋Š” ๋ฌธ์ œ๋ฅผ ๊ฐœ์„ ํ•œ ๊ตฌ์กฐ์ด๋‹ค.

LSTM์€ ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•  ๋•Œ, ๋งค ์‹œ๊ฐ„ ๋‹จ๊ณ„์—์„œ ์•„๋ž˜์™€ ๊ฐ™์€ ์„ธ ๊ฐ€์ง€ ์ •๋ณด ํ๋ฆ„์„ ์ฒ˜๋ฆฌํ•œ๋‹ค.

  • ์…€ ์ƒํƒœ(cell state): ์žฅ๊ธฐ ๊ธฐ์–ต์„ ์ „๋‹ฌํ•˜๋Š” ๋ฒกํ„ฐ
  • ์€๋‹‰ ์ƒํƒœ(hidden state): ๋‹ค์Œ ๊ณ„์ธต ๋˜๋Š” ๋‹ค์Œ ์‹œ๊ฐ„ ๋‹จ๊ณ„๋กœ ์ถœ๋ ฅํ•  ์ •๋ณด
  • ๊ฒŒ์ดํŠธ๋“ค: ์ •๋ณด๋ฅผ ์–ผ๋งˆ๋‚˜ ์œ ์ง€ํ• ์ง€, ๋ฒ„๋ฆด์ง€ ์ถœ๋ ฅํ• ์ง€๋ฅผ ๊ฒฐ์ •

 

  • ์ž‘๋™ ๋ฐฉ์‹
    • ์ž…๋ ฅ ๊ฒŒ์ดํŠธ๋Š” ํ˜„์žฌ ์ž…๋ ฅ์—์„œ ์…€ ์ƒํƒœ์— ์–ผ๋งˆ๋‚˜ ๋ฐ˜์˜ํ• ์ง€๋ฅผ ๊ฒฐ์ •
    • ๋ง๊ฐ ๊ฒŒ์ดํŠธ๋Š” ์ด์ „ ์…€ ์ƒํƒœ์˜ ์ •๋ณด๋ฅผ ์–ผ๋งˆ๋‚˜ ์žŠ์„์ง€๋ฅผ ๊ฒฐ์ •
    • ์ถœ๋ ฅ ๊ฒŒ์ดํŠธ๋Š” ์ตœ์ข… ์€๋‹‰ ์ƒํƒœ๋ฅผ ์–ผ๋งˆ๋‚˜ ๋‚ด๋ณด๋‚ผ์ง€๋ฅผ ์กฐ์ ˆ
    • ์ด๋Ÿฌํ•œ ๊ฒŒ์ดํŠธ ์—ฐ์‚ฐ์€ ์‹œ๊ทธ๋ชจ์ด๋“œ ํ•จ์ˆ˜๋กœ 0~1 ๊ฐ’์„ ์ถœ๋ ฅํ•˜์—ฌ ์ •๋ณด ํ๋ฆ„์„ ์กฐ์ ˆ 

 

import torch.nn as nn

class LSTMModel(nn.Module):
    # Define the model structure
    def __init__(self, input_dim=13, hidden_dim=128, output_dim=len(char2idx)+1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    # Define the forward pass
    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out)
        return out

    # Run backpropagation on the loss
    def backward(self, loss):
        loss.backward()

 
