# Source listing: 58 lines, 1.7 KiB, Python
import numpy as np
|
|
|
|
class LSTM:
    """A single long short-term memory (LSTM) cell implemented with NumPy.

    The cell holds one weight matrix and one bias column per gate: forget
    (``Wf``/``bf``), input (``Wi``/``bi``), candidate cell state
    (``WC``/``bC``) and output (``Wo``/``bo``).  Each weight matrix has shape
    ``(hidden_dim, hidden_dim + input_dim)`` because it is applied to the
    previous hidden state stacked on top of the current input; each bias has
    shape ``(hidden_dim, 1)``.  Only the forward pass is implemented.
    """

    def __init__(self, input_dim, hidden_dim, rng=None):
        """Create a cell with randomly initialized parameters.

        Args:
            input_dim: Number of features in each input column ``x_t``.
            hidden_dim: Number of hidden units (rows of ``h`` and ``C``).
            rng: Optional random source exposing ``uniform(low, high, size)``,
                e.g. ``np.random.default_rng(seed)`` or a
                ``np.random.RandomState``.  When omitted, the global
                ``np.random`` state is used, so ``np.random.seed`` still makes
                initialization reproducible.

        Raises:
            ValueError: If ``hidden_dim`` < 1 or ``input_dim`` < 0.
        """
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be at least 1, got {hidden_dim}")
        if input_dim < 0:
            raise ValueError(f"input_dim must be non-negative, got {input_dim}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        rng = np.random if rng is None else rng
        # Zero-centred draws from U(-k, k), k = 1/sqrt(hidden_dim).  All-positive
        # U[0, 1) weights make every pre-activation large and positive as the
        # fan-in grows, which saturates the gates.
        k = 1.0 / np.sqrt(hidden_dim)
        w_shape = (hidden_dim, hidden_dim + input_dim)
        b_shape = (hidden_dim, 1)

        def _init(shape):
            # Draw one parameter array from U(-k, k).
            return rng.uniform(-k, k, size=shape)

        # Forget gate
        self.Wf = _init(w_shape)
        self.bf = _init(b_shape)

        # Input gate
        self.Wi = _init(w_shape)
        self.bi = _init(b_shape)

        # Candidate cell state
        self.WC = _init(w_shape)
        self.bC = _init(b_shape)

        # Output gate
        self.Wo = _init(w_shape)
        self.bo = _init(b_shape)

    def sigmoid(self, x):
        """Return the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

        The exponent passed to ``np.exp`` is ``-|x|``, which is never positive.
        That avoids the overflow warning a direct ``np.exp(-x)`` raises for
        large negative ``x`` and keeps full precision in both tails.
        """
        x = np.asarray(x)
        z = np.exp(-np.abs(x))
        # For x >= 0: 1 / (1 + e^-x).  For x < 0: e^x / (1 + e^x), the same
        # value rewritten so that only e^-|x| is needed.
        return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))

    def tanh(self, x):
        """Return the element-wise hyperbolic tangent of ``x``."""
        return np.tanh(x)

    def forward(self, x_t, h_prev, C_prev):
        """Advance the cell by one time step.

        Inputs are column-oriented: ``x_t`` has shape ``(input_dim, B)`` and
        ``h_prev`` / ``C_prev`` have ``hidden_dim`` rows.  ``B`` is usually 1.
        Several independent sequences can be stepped at once by putting them
        in separate columns, because the ``(hidden_dim, 1)`` biases broadcast
        across columns.

        Args:
            x_t: Current input.
            h_prev: Previous hidden state.
            C_prev: Previous cell state.

        Returns:
            ``(h_t, C_t)``: the new hidden state and cell state.

        Raises:
            ValueError: If an argument is not 2-D or has the wrong number of
                rows, or if the column counts are incompatible.
        """
        x_t = np.asarray(x_t)
        h_prev = np.asarray(h_prev)
        C_prev = np.asarray(C_prev)

        # Derive the expected sizes from the weights so the check stays valid
        # even if the parameter arrays are replaced after construction.
        hidden_dim = self.Wf.shape[0]
        input_dim = self.Wf.shape[1] - hidden_dim
        for name, arr, rows in (
            ("x_t", x_t, input_dim),
            ("h_prev", h_prev, hidden_dim),
            ("C_prev", C_prev, hidden_dim),
        ):
            # A 1-D C_prev would otherwise broadcast against the
            # (hidden_dim, 1) gates into a (hidden_dim, hidden_dim) matrix and
            # silently corrupt the cell state.
            if arr.ndim != 2 or arr.shape[0] != rows:
                raise ValueError(
                    f"{name} must be a 2-D array with {rows} rows, "
                    f"got shape {arr.shape}"
                )

        # Stack the previous hidden state on top of the current input.
        combined = np.vstack((h_prev, x_t))

        # Forget gate: how much of the previous cell state to keep.
        f_t = self.sigmoid(self.Wf @ combined + self.bf)

        # Input gate and candidate values to write into the cell.
        i_t = self.sigmoid(self.Wi @ combined + self.bi)
        C_tilde = self.tanh(self.WC @ combined + self.bC)

        # New cell state.
        C_t = f_t * C_prev + i_t * C_tilde

        # Output gate selects what part of the squashed cell state to expose.
        o_t = self.sigmoid(self.Wo @ combined + self.bo)
        h_t = o_t * self.tanh(C_t)

        return h_t, C_t
|
|
|
|
def main():
    """Run one forward step of a small LSTM cell and print the new states."""
    input_dim = 5  # Input feature size
    hidden_dim = 3  # Number of hidden units
    lstm = LSTM(input_dim, hidden_dim)

    # Start from all-zero states, as for the first step of a sequence.
    h_prev = np.zeros((hidden_dim, 1))  # Previous hidden state
    C_prev = np.zeros((hidden_dim, 1))  # Previous cell state
    x_t = np.random.rand(input_dim, 1)  # Current input (column vector)

    # Forward pass
    h_t, C_t = lstm.forward(x_t, h_prev, C_prev)

    print("Current hidden state:", h_t)
    print("Current cell state:", C_t)


# Only run the demo when executed as a script, not on import.
if __name__ == "__main__":
    main()