본문 바로가기

강화학습 강의 복습노트

Part2 - 8. SARSA : TD기법을 활용한 최적 정책 찾기

지난 TD기법에 대한 포스트에서 다뤘던 TD기법에 대해 다시 복습해보자.

 

Part2 - 6. Temporal Difference(TD) 정책추정

1) DP와 MC 기법의 장단점 DP기법의 경우 각 상태와 행동의 관계를 최대한 활용해 계산량을 줄인다는 장점을 가진 반면에 환경에 대한 모델이 없으면 계산이 불가능하다는 단점을 가지고 있다. MC

hh-bigdata-career.tistory.com

 

그림1. TD기법의 가치 추산 식

지금까지 $V(s)$를 추산하여 $V(s) \rightarrow Q(s,a)$를 도출하는 방법에 대해서 배웠다.

 

그렇다면 Q(s, a)에 대해서도 비슷한 방식으로 추산할 수 있지 않을까??

 

TD(0)를 활용한 행동 가치함수 $Q^{\pi}$를 추산하는 과정을 살펴보자.

 

1) SARSA

그림2. SARSA update의 구조

SARSA업데이트는 상태 s에서 policy에 따라 행동 a를 한 후 보상 r을 받는다. 그 후 다음상태 s'로 이동하고, policy에 따라 다음행동 a'를 얻는다.

 

이렇게 얻은 a와 s,r을 통해 Q(s,a)를 업데이트 하는 방식이 SARSA 업데이트이다.

※ TMI : State Action Reward State Action이라서 SARSA이다.

 

그림3. SARSA의 Pseudo Code

SARSA의 의사코드를 살펴보자.

전체적인 흐름은 TD와 비슷하나 반복문이 중첩되어있다.

외부의 반복문은 Q(s,a)가 수렴할 때까지 진행되며, 내부의 반복문은 s가 종결상태일 때 까지 진행된다.

 

이 알고리즘을 활용하면 n-step SARSA 또한 구현이 가능하다.

그림4. n-step SARSA

n-step SARSA에서 target은 $q_{t}^{(n)}$으로 표현한다.

 

SARSA update의 장점은 모든 episode를 돌지않아도 policy improvement가 가능하다는 것이다.

(때문에 episode가 길다면 매우 효율적임)

현재상태와 행동, 다음상태와 행동만 주어진다면 업데이트가 가능하기 때문이다.

 

TD(0)를 가정하여 SARSA기법을 Python 코드로 구현하면 다음과 같다.

 

num_eps = 10000
sarsa_rewards = []


for i in range(num_eps):
    
    reward_sum = 0
    env.reset()    
    while True:
        state = env.s
        action = sarsa_agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        next_action = sarsa_agent.get_action(next_state)
        
        sarsa_agent.update_sample(state=state,
                                  action=action,
                                  reward=reward,
                                  next_state=next_state,
                                  next_action=next_action,
                                  done=done)
        reward_sum += reward
        if done:
            break
    
    sarsa_rewards.append(reward_sum)

next_action을 받는점과 sarsa_agent.update_sample을 제외하고는 나머지와 비슷하다.

 


< SARSA 알고리즘 변수 및 함수 설명 >

sarsa_agent.update_sample(state, action, reward, next_state, next_action, done) :

def update_sample(self, state, action, reward, next_state, next_action, done):
    s, a, r, ns, na = state, action, reward, next_state, next_action

    # SARSA target
    td_target = r + self.gamma * self.q[ns, na] * (1 - done)
    self.q[s, a] += self.lr * (td_target - self.q[s, a])

 

TD(0)를 가정한 함수이므로 td_target = $R_{t+1} + \gamma Q(S_{t+1})$이다.

done은 역시 앞 포스트에서 설명한 대로 종결상태 여부를 체크해주는 변수이다.

구해진 td_target을 가지고 $Q(s,a)$를 업데이트 해주면 끝이다.

 

기존의 TD 기법을 알고있으므로 이해가 쉬울것이라 생각된다.

n-step SARSA역시 n-step TD알고리즘과 비슷하게 구현하면 될 것이다.


 

2) SARSA($\lambda$)

SARSA도 TD와 마찬가지로 SARSA($\lambda$)의 형식으로 구현이 가능하다.

그림5. SARSA($\lambda$)

 

이상으로 SARSA기법에 대해 알아보았다.

 


내 생각 : 

더보기

SARSA같은 경우 TD에 대한 기본지식이 있으니 이해하기가 쉬웠다.

그리고 바로 다음의 상태와 다음의 행동만을 가지고 행동 가치함수 Q를 업데이트 할 수 있다는 장점이

현실에서 사용 할 때 큰 이점으로 다가올 것같다.

 

점점더 현실세계에 적용할 수 있을 법한 알고리즘을 배우니 신난다.복습을 빨리 끝내고 프로젝트를 해보고싶은 생각에 들뜬다.

참고자료 :

강화학습 A-Z 강의자료