- 데이터 준비 : get_data() 함수는 입력과 출력 데이터로 구성된 Tensor 배열을 제공합니다.
- 학습 파라미터 생성 : get_weights() 함수는 임의의 수를 포함하는 Tensor를 제공합니다. 이 임의의 수는 이 문제의 해법을 구하기 위해서 최적화 됩니다.
- 네트워크 모델 : simple_network() 는 선형 규칙을 적용하고 가중치에 입력 데이터를 곱하고 바이어스 항 (y = Wx + b)을 더하여 입력 데이터에 대한 출력을 생성합니다
- 오차 : loss_fn() 함수는 모델의 얼마나 좋은지에 대한 정보를 제공합니다.
- 옵티마이저 : optimize() 함수는 초기에 생성 된 임의의 가중치를 조정하여 모델이 목표 값을보다 정확하게 계산하도록합니다.
def get_data():
train_X = np.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,
7.042,10.791,5.313,7.997,5.654,9.27,3.1])
train_Y = np.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
2.827,3.465,1.65,2.904,2.42,2.94,1.3])
dtype = torch.FloatTensor
X = Variable(torch.from_numpy(train_X).type(dtype),requires_grad=False).view(17,1)
y = Variable(torch.from_numpy(train_Y).type(dtype),requires_grad=False)
return X,y
def plot_variable(x,y,z='',**kwargs):
l = []
for a in [x,y]:
#1.0 업그레이드에서 필요 없는 코드
#if type(a) == Variable:
l.append(a.data.numpy())
plt.plot(l[0],l[1],z,**kwargs)
def get_weights():
w = Variable(torch.randn(1),requires_grad = True)
b = Variable(torch.randn(1),requires_grad=True)
return w,b
def simple_network(x):
y_pred = torch.matmul(x,w)+b
return y_pred
#backward는 학습 파라미터 w와 b의 변화 정도를 기울기로 계산한다.
def loss_fn(y,y_pred):
loss = (y_pred-y).pow(2).sum()
for param in [w,b]:
if not param.grad is None: param.grad.data.zero_()
loss.backward()
return loss.data
#학습 파라미터를 조정함으로써 모델의 성능을 향상시켰다.
def optimize(learning_rate):
w.data -= learning_rate * w.grad.data
b.data -= learning_rate * b.grad.data
learning_rate = 1e-4
x,y = get_data() # x - 학습 데이터, y - 목적 변수(Target Variables)
w,b = get_weights() # w,b - 학습 파라미터
for i in range(500):
y_pred = simple_network(x) # wx + b를 계산하는 함수
loss = loss_fn(y,y_pred) # y와 y_pred의 차의 제곱 합을 계산
if i % 50 == 0:
print(loss)
optimize(learning_rate) # 오차를 최소화하도록 w, b를 조정
input값을 모델에 순전파 시키고, 실제값과 순전파시킨 가설값을 손실함수에 넣는다
손실함수의 미분값을 최소화 하는 방향으로 optimize를 통해 학습 파라미터 값을 수정한다.
'딥러닝' 카테고리의 다른 글
코랩 런타임 유지 (0) | 2022.08.19 |
---|---|
Open CV (0) | 2022.08.02 |
가중치 규제 (Norm) (0) | 2022.07.29 |
손실함수 정리 (0) | 2022.07.27 |
Sigmoid 함수 미분 (0) | 2022.07.27 |