Q-learning

  I've recently started studying various algorithms. My current approach is to get hands-on with code quickly, build an intuitive feel for each algorithm, and only then go back to the papers and formulas.

  Q-learning is a very basic off-policy reinforcement learning algorithm and a good fit for beginners. Because reinforcement learning needs a custom environment for each problem, there is no general-purpose library and you have to write the program yourself. A detailed explanation of the algorithm can be found at wiki-Q-learning. Below are two practice cases: the one-dimensional case follows 莫烦's tutorial, and I extended it to two dimensions on that basis, which gives an intuitive feel for a simple reinforcement learning algorithm.

Outline

  1. Prepare the state-action (Q) table and the reward table
  2. Randomly initialize the state
  3. Choose an action according to the Q table
  4. Step the environment
  5. Receive the reward and the new state
  6. Compute the target Q value for the previous state
  7. Update the Q table from the two Q values (prediction and target) of the previous state
  8. Transition to the new state and repeat (a compressed sketch of this loop follows the list)
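
  These eight steps compress into a short loop. Below is a minimal, self-contained sketch on a toy six-state chain with a reward of 1 at the right end; the constants and names here (n_states, alpha, gamma, epsilon) are illustrative and separate from the full programs later in the post.

import numpy as np

n_states, n_actions = 6, 2                 # step 1: states and actions (0 = left, 1 = right)
q_table = np.zeros((n_states, n_actions))  # step 1: Q table, all zeros
alpha, gamma, epsilon = 0.1, 0.9, 0.9      # illustrative hyperparameters

for episode in range(20):
    s = np.random.randint(n_states - 1)    # step 2: random non-terminal initial state
    while s != n_states - 1:               # the right-most state is terminal
        if np.random.uniform() < epsilon and q_table[s].any():
            a = int(q_table[s].argmax())   # step 3: exploit the Q table
        else:
            a = np.random.randint(n_actions)                             # step 3: explore
        s_next = min(s + 1, n_states - 1) if a == 1 else max(s - 1, 0)   # step 4: environment moves
        r = 1.0 if s_next == n_states - 1 else 0.0                       # step 5: reward and new state
        q_predict = q_table[s, a]                                        # step 6: old estimate
        q_target = r if s_next == n_states - 1 else r + gamma * q_table[s_next].max()
        q_table[s, a] += alpha * (q_target - q_predict)                  # step 7: update the Q table
        s = s_next                                                       # step 8: state transition
print(q_table)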

Pseudocode

  1. Initialize $Q(s,a), ∀s ∈ S, a ∈ A(s)$, arbitrarily, and $Q(terminal-state, ·) = 0$
  2. Repeat (for each episode):
  3.     Initialize $S$
  4.     Repeat (for each step of episode):
  5.        Choose $A$ from $S$ using policy derived from $Q$ (e.g., $ϵ$-greedy)
  6.        Take action $A$, observe $R, S'$
  7.        $Q(S,A) ← Q(S,A) + α[R + γ max_a Q(S',a) - Q(S,A)]$
  8.        $S ← S'$
  9.     until $S$ is terminal
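
  As a concrete instance of the update on line 7: in the one-dimensional game below, the first time the agent stands one cell left of the treasure ($S = 4$), moves right, and receives $R = 1$, the episode terminates and every Q value is still 0, so the target is just $R$. With $α = 0.1$ the update is $Q(4, right) ← 0 + 0.1[1 - 0] = 0.1$; subsequent episodes then propagate this value backwards through the $γ max_a Q(S',a)$ term.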

One-dimensional exploration game

import numpy as np
import pandas as pd
import time

np.random.seed(2)        # reproducible
N_STATES = 6             # length of the 1D world; the right-most cell is the treasure 'T'
ACTIONS = ['up', 'down', 'left', 'right']   # in this 1D game only 'right' moves right; any other action moves left
EPSILON = 0.9            # epsilon-greedy: probability of acting greedily
ALPHA = 0.1              # learning rate
LAMBDA = 0.9             # discount factor (gamma)
MAX_EPISODES = 12        # maximum episodes
FRESH_TIME = 0.3         # refresh interval per step (seconds)


def build_q_table(n_states, actions):
    table = pd.DataFrame(
        np.zeros((n_states, len(actions))),   # initialize all Q values with 0
        columns=actions,
    )
    return table


def choose_action(state, q_table):
    state_actions = q_table.iloc[state, :]
    # explore when the random draw exceeds EPSILON, or when this state is still all zeros (first visit)
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):
        action = np.random.choice(ACTIONS)    # non-greedy: act randomly
    else:
        action = state_actions.idxmax()       # greedy: pick the action with the largest Q value
    return action


def get_env_feedback(S, A):
    # the environment reacts to the agent's action and returns the next state and reward
    R = 0
    if A == 'right':
        if S == N_STATES - 2:    # one step left of the treasure: reaching it ends the episode
            S_next = 'terminal'
            R = 1
        else:
            S_next = S + 1
    else:                        # any other action moves left
        if S == 0:
            S_next = S           # hit the left wall
        else:
            S_next = S - 1
    return S_next, R


def update_env(S, episode, step_counter, q_table):
    # render the 1D world, e.g. '--0--T'
    env_list = ['-'] * (N_STATES - 1) + ['T']
    if S == 'terminal':
        interaction = 'Episode %s: total_steps = %s' % (episode + 1, step_counter)
        print('\r{}'.format(interaction), end='')
        time.sleep(2)
        print('\r                                ', end='')   # clear the line
    else:
        env_list[S] = '0'
        interaction = ''.join(env_list)
        print('\r{}'.format(interaction), end='')
        # print('\n', q_table)
        time.sleep(FRESH_TIME)
    return


def rf():
    q_table = build_q_table(N_STATES, ACTIONS)
    for episode in range(MAX_EPISODES):
        step_counter = 0
        S = 0                    # initial position
        is_terminated = False
        update_env(S, episode, step_counter, q_table)
        while not is_terminated:
            A = choose_action(S, q_table)
            S_next, R = get_env_feedback(S, A)
            q_predict = q_table.loc[S, A]
            if S_next != 'terminal':
                q_target = R + LAMBDA * q_table.iloc[S_next, :].max()
            else:
                q_target = R
                is_terminated = True
            # with a learning rate of 1 the target would simply be assigned directly
            q_table.loc[S, A] += ALPHA * (q_target - q_predict)   # move Q toward the long-term target
            S = S_next
            update_env(S, episode, step_counter + 1, q_table)
            step_counter += 1
    return q_table


if __name__ == '__main__':
    q_table = rf()
    print('\r\nQ-table:\r')
    print(q_table)
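
The epsilon-greedy behaviour of choose_action can be sanity-checked in isolation. A minimal sketch, assuming the definitions above have already been executed in the same session (the value 0.5 below is an arbitrary placeholder):

q = build_q_table(N_STATES, ACTIONS)
q.loc[0, 'right'] = 0.5                    # pretend state 0 has already learned that 'right' pays off
picks = [choose_action(0, q) for _ in range(1000)]
print(picks.count('right') / len(picks))   # about 0.9 + 0.1/4 = 0.925, i.e. mostly greedy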

Two-dimensional exploration game

import numpy as np
import pandas as pd
import time

np.random.seed(7)        # reproducible
N_STATES = 6, 8          # grid size: 6 rows x 8 columns
TARGET = 2, 3            # target cell index (row 2, col 3), i.e. the third row, fourth column

ACTIONS = ['up', 'down', 'left', 'right']
EPSILON = 0.9            # epsilon-greedy: probability of acting greedily
ALPHA = 0.1              # learning rate
LAMBDA = 0.9             # discount factor (gamma)
MAX_EPISODES = 100       # maximum episodes
FRESH_TIME = 0.2         # refresh interval per step (seconds)


def build_q_table(n_states, actions):
    # one row per grid cell (flattened row-major), one column per action
    table = pd.DataFrame(
        np.zeros((n_states[0] * n_states[1], len(actions))),   # initialize all Q values with 0
        columns=actions,
    )
    return table


def build_reward_table(n_states):
    # reward is 0 everywhere except 100 at the target cell
    table = pd.DataFrame(
        np.zeros((n_states[0], n_states[1])),
    )
    table.iloc[TARGET] = 100
    return table


def choose_action(state, q_table):
    index_S = state[0] * N_STATES[1] + state[1]   # flatten (row, col) to a Q-table row index
    state_actions = q_table.iloc[index_S, :]
    # explore when the random draw exceeds EPSILON, or when this state is still all zeros (first visit)
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):
        action = np.random.choice(ACTIONS)        # non-greedy: act randomly
    else:
        action = state_actions.idxmax()           # greedy: pick the action with the largest Q value
    return action


def get_env_feedback(S, A, r_table):
    # the environment reacts to the agent's action and returns the next state and reward
    if A == 'right':
        if S[1] == N_STATES[1] - 1:
            S_next = S                  # hit the right wall
        else:
            S_next = S[0], S[1] + 1
    elif A == 'left':
        if S[1] == 0:
            S_next = S                  # hit the left wall
        else:
            S_next = S[0], S[1] - 1
    elif A == 'up':
        if S[0] == 0:
            S_next = S                  # hit the top wall
        else:
            S_next = S[0] - 1, S[1]
    elif A == 'down':
        if S[0] == N_STATES[0] - 1:
            S_next = S                  # hit the bottom wall
        else:
            S_next = S[0] + 1, S[1]
    R = r_table.iloc[S_next]            # reward of the landing cell
    print(S, A, S_next)                 # debug trace of each transition
    return S_next, R


def update_env(S, episode, step_counter, q_table):
    # render the grid: '-' for empty cells, 'T' for the target, '0' for the agent
    env_list = pd.DataFrame(np.zeros((N_STATES[0], N_STATES[1])))
    env_list.loc[:] = '-'
    env_list.iloc[TARGET] = 'T'
    if S == TARGET:
        env_list.iloc[S] = '0'
        interaction = 'Episode %s: total_steps = %s' % (episode + 1, step_counter)
        print('\r{}'.format(interaction), end='')
        time.sleep(2)
        print('\r                                ', end='')   # clear the line
    else:
        env_list.iloc[S] = '0'
        print('\r{}'.format(env_list), end='')
        # print('\n', q_table)
        time.sleep(FRESH_TIME)
        print('\n#############NEXT ROUND##################')
    return


def rf():
    q_table = build_q_table(N_STATES, ACTIONS)
    r_table = build_reward_table(N_STATES)
    for episode in range(MAX_EPISODES):
        step_counter = 0
        # random initial position (row, col)
        S = np.random.randint(0, N_STATES[0]), np.random.randint(0, N_STATES[1])
        index_S = S[0] * N_STATES[1] + S[1]
        print('S', S, 'To', index_S)
        is_terminated = False
        update_env(S, episode, step_counter, q_table)
        while not is_terminated:
            A = choose_action(S, q_table)
            S_next, R = get_env_feedback(S, A, r_table)
            q_predict = q_table.loc[index_S, A]
            if S_next != TARGET:
                index_Sn = S_next[0] * N_STATES[1] + S_next[1]
                print('S_next', S_next, 'To', index_Sn)
                q_target = R + LAMBDA * q_table.iloc[index_Sn, :].max()
            else:
                q_target = R
                is_terminated = True
            # with a learning rate of 1 the target would simply be assigned directly
            q_table.loc[index_S, A] += ALPHA * (q_target - q_predict)   # move Q toward the long-term target
            S = S_next
            index_S = S[0] * N_STATES[1] + S[1]
            update_env(S, episode, step_counter + 1, q_table)
            step_counter += 1
    return q_table


if __name__ == '__main__':
    q_table = rf()
    print('\r\nQ-table:\r')
    print(q_table)
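
The 2D program stores the grid in a flat Q-table, so each cell (row, col) is mapped to a single row index via index_S = S[0]*N_STATES[1] + S[1]. A minimal sketch of the mapping and its inverse, useful when reading the Q-table below; the values here simply restate the constants from the program:

N_STATES = 6, 8
TARGET = 2, 3
index_T = TARGET[0] * N_STATES[1] + TARGET[1]   # 2*8 + 3 = 19
print(index_T)                                  # 19: the target's row in the Q-table
print(divmod(index_T, N_STATES[1]))             # (2, 3): divmod inverts the mapping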

After training, the agent takes the shortest path to the target, and the Q-table is printed at the end:

   0  1  2  3  4  5  6  7
0  -  -  -  -  -  -  -  -
1  -  -  -  0  -  -  -  -
2  -  -  -  T  -  -  -  -
3  -  -  -  -  -  -  -  -
4  -  -  -  -  -  -  -  -
5  -  -  -  -  -  -  -  -

Q-table:
up down left right
0 7.307904e-04 1.904577e-03 0.000000e+00 0.007859
1 7.858897e-03 7.556999e-01 0.000000e+00 0.402958
2 2.631690e-02 8.414524e+00 3.626621e-02 1.745023
3 8.100000e-02 2.261374e+01 1.975590e-02 0.013851
4 3.222180e-02 1.211320e+01 8.100000e-02 0.001837
5 2.368521e-03 0.000000e+00 6.363587e-02 0.000218
6 5.904900e-05 1.516812e-01 1.837080e-03 0.000164
7 3.472489e-05 3.289256e-02 5.904900e-05 0.000005
8 7.307904e-04 1.339231e-06 2.994458e-03 1.414500
9 3.988592e-02 4.304672e-08 7.212840e-02 11.509400
10 5.065821e-02 1.484958e-01 3.619704e-04 55.499094
11 2.268000e-01 9.749684e+01 3.653100e-01 5.496623
12 9.978600e-01 7.265790e+00 6.105641e+01 0.215999
13 5.509928e-03 5.940395e-03 2.289216e+01 0.097409
14 5.010639e-04 5.947947e-05 4.372300e+00 0.000478
15 4.903927e-04 1.827626e-05 5.142739e-01 0.002211
16 2.919956e-03 0.000000e+00 0.000000e+00 0.000000
17 2.114380e-01 7.775514e-02 4.782969e-07 3.301809
18 3.003703e+00 1.249748e+01 0.000000e+00 27.100000
19 1.521753e+01 0.000000e+00 9.000000e-01 0.000000
20 6.642014e+00 2.268000e-01 4.685590e+01 0.226800
21 6.600439e-02 2.631690e-02 6.609690e+00 0.000257
22 2.851123e-03 1.605384e-02 2.195100e-01 0.000000
23 1.665721e-02 0.000000e+00 1.975590e-02 0.001809
24 0.000000e+00 1.571279e-01 0.000000e+00 0.000000
25 1.591809e-01 2.455123e+00 0.000000e+00 11.233118
26 1.744314e+00 4.288338e+00 5.544728e-01 79.152640
27 9.972611e+01 4.540708e+00 6.069450e+00 18.925634
28 6.493028e+00 1.346001e+00 8.319825e+01 0.050002
29 2.268000e-01 6.949118e-02 5.408570e+01 0.436982
30 4.782969e-07 1.153245e-03 2.298841e+01 0.000002
31 1.778031e-03 1.661233e-04 0.000000e+00 0.000000
32 9.701053e-05 1.565183e-04 1.436938e-03 2.462021
33 2.041200e-02 5.733493e-03 1.985465e-02 18.339899
34 5.023670e+01 1.246590e-03 1.296716e-02 1.431899
35 2.967056e+01 0.000000e+00 8.428189e+00 0.241540
36 2.075822e+01 0.000000e+00 2.195100e-01 0.332408
37 9.607798e+00 1.826627e-03 3.878280e-02 0.000405
38 4.855764e+00 3.645154e-05 2.368521e-03 0.000019
39 0.000000e+00 6.233213e-06 3.998927e-02 0.000019
40 1.905050e-01 2.736390e-05 7.344711e-04 0.000213
41 4.621959e-01 7.653708e-04 1.102849e-04 0.001247
42 1.385100e-02 1.778031e-03 9.572210e-04 0.100822
43 2.909448e+00 7.163884e-04 0.000000e+00 0.015124
44 9.807243e-01 7.959871e-03 2.040065e-02 0.000961
45 3.069987e-02 0.000000e+00 0.000000e+00 0.000118
46 6.075773e-01 6.925874e-05 4.050171e-04 0.000023
47 4.556314e-04 8.890530e-06 3.241323e-02 0.000000
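
To read the learned policy off this table, take the action with the largest Q value in each row and reshape back onto the grid. A minimal sketch, assuming q_table and N_STATES from the 2D program are still in scope:

policy = q_table.idxmax(axis=1).values.reshape(N_STATES)   # best action per cell
print(pd.DataFrame(policy))
# e.g. row 27 of the table above is cell (3, 3), directly below the target, and its
# largest value is in the 'up' column, so the greedy policy moves up toward 'T'.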