Fix args for trppo
This commit is contained in:
		
							parent
							
								
									5dfc162d3c
								
							
						
					
					
						commit
						0bdc528224
					
				@ -145,6 +145,27 @@ class Model(object):
 | 
			
		||||
            ) * ADVS
 | 
			
		||||
            pg_loss = -tf.reduce_mean(pg_targets)
 | 
			
		||||
            clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
 | 
			
		||||
        elif cliptype == ClipType.ratio_rollback_constant:
 | 
			
		||||
            slope = args.clipargs.slope_rollback
 | 
			
		||||
            pg_targets = tf.where(
 | 
			
		||||
                ADVS >= 0,
 | 
			
		||||
                tf.where( ratio <= 1 + CLIPRANGE,
 | 
			
		||||
                                ratio*ADVS,
 | 
			
		||||
                                slope * ratio ), # When ratio=1+CLIPRANGE, the corresponding value should also be 1+CLIPRANGE
 | 
			
		||||
                tf.where( ratio >= 1 - CLIPRANGE,
 | 
			
		||||
                                ratio*ADVS,
 | 
			
		||||
                                -slope * ratio )
 | 
			
		||||
            )
 | 
			
		||||
            pg_loss = -tf.reduce_mean(pg_targets)
 | 
			
		||||
            clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
 | 
			
		||||
        elif cliptype == ClipType.ratio_strict:
 | 
			
		||||
            pg_losses2 = -ADVS * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)
 | 
			
		||||
            pg_loss = tf.reduce_mean(pg_losses2)
 | 
			
		||||
            clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
 | 
			
		||||
        elif cliptype == ClipType.a2c:
 | 
			
		||||
            pg_losses = -ADVS * ratio
 | 
			
		||||
            pg_loss = tf.reduce_mean(pg_losses)
 | 
			
		||||
            clipfrac = tf.constant(0)
 | 
			
		||||
 | 
			
		||||
        elif cliptype == ClipType.kl:
 | 
			
		||||
            # version by hugo
 | 
			
		||||
@ -198,6 +219,52 @@ class Model(object):
 | 
			
		||||
            pg_loss = -tf.reduce_mean(pg_targets)
 | 
			
		||||
            clipfrac = tf.reduce_mean(
 | 
			
		||||
                tf.to_float(tf.logical_and(kl >= KLRANGE, ratio*ADVS>ADVS)))
 | 
			
		||||
 | 
			
		||||
        elif cliptype == ClipType.kl_klrollback_constant:
 | 
			
		||||
            # The slope of the objective is switched once the kl exceed.
 | 
			
		||||
            slope = args.clipargs.slope_rollback
 | 
			
		||||
            # version by hugo
 | 
			
		||||
            # pg_losses = tf.where(kl <= KLRANGE, ADV * ratio,
 | 
			
		||||
            #                      tf.where(tf.logical_and(ratio > 1., ADV > 0), slope * ratio * ADV,
 | 
			
		||||
            #                       tf.where(tf.logical_and(ratio < 1., ADV < 0.), slope * ratio * ADV, ADV * ratio)))
 | 
			
		||||
            # version by siuming
 | 
			
		||||
            pg_targets = tf.where(
 | 
			
		||||
                tf.logical_and( kl >= KLRANGE, ratio * ADVS > 1 * ADVS),
 | 
			
		||||
                slope * kl,
 | 
			
		||||
                ratio * ADVS
 | 
			
		||||
            )
 | 
			
		||||
            pg_loss = -tf.reduce_mean(pg_targets)
 | 
			
		||||
            clipfrac = tf.reduce_mean(
 | 
			
		||||
                tf.to_float(tf.logical_and(kl >= KLRANGE, ratio*ADVS>ADVS)))
 | 
			
		||||
 | 
			
		||||
        elif cliptype == ClipType.kl_klrollback:
 | 
			
		||||
            # The slope of the objective is switched once the kl exceed.
 | 
			
		||||
            slope = args.clipargs.slope_rollback
 | 
			
		||||
            # version by hugo
 | 
			
		||||
            # pg_losses = tf.where(kl <= KLRANGE, ADV * ratio,
 | 
			
		||||
            #                      tf.where(tf.logical_and(ratio > 1., ADV > 0), slope * ratio * ADV,
 | 
			
		||||
            #                       tf.where(tf.logical_and(ratio < 1., ADV < 0.), slope * ratio * ADV, ADV * ratio)))
 | 
			
		||||
            # version by siuming
 | 
			
		||||
            pg_targets = tf.where(
 | 
			
		||||
                tf.logical_and( kl >= KLRANGE, ratio * ADVS > 1 * ADVS),
 | 
			
		||||
                slope * kl * tf.abs(ADVS),
 | 
			
		||||
                ratio * ADVS
 | 
			
		||||
            )
 | 
			
		||||
            pg_loss = -tf.reduce_mean(pg_targets)
 | 
			
		||||
            clipfrac = tf.reduce_mean(
 | 
			
		||||
                tf.to_float(tf.logical_and(kl >= KLRANGE, ratio*ADVS>ADVS)))
 | 
			
		||||
        elif cliptype == ClipType.kl_strict:
 | 
			
		||||
            pg_losses = -ADVS * ratio
 | 
			
		||||
            pg_losses = tf.where(
 | 
			
		||||
                kl >= KLRANGE,
 | 
			
		||||
                tf.stop_gradient(pg_losses, name='pg_losses_notrain'),
 | 
			
		||||
                pg_losses
 | 
			
		||||
            )
 | 
			
		||||
            pg_loss = tf.reduce_mean(pg_losses)
 | 
			
		||||
            clipfrac = tf.reduce_mean(tf.to_float( kl >= KLRANGE ))
 | 
			
		||||
        elif cliptype == ClipType.adaptivekl:
 | 
			
		||||
            pg_loss = tf.reduce_mean(-ADVS * ratio) + tf.reduce_mean(kl) * KL_COEF
 | 
			
		||||
            clipfrac = tf.constant(0.)
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
@ -513,93 +580,7 @@ def learn(*, policy, env, env_eval, n_steps, total_timesteps, ent_coef, lr,
 | 
			
		||||
            actions=actions, values_old=values, neglogpacs_old=neglogpacs, advs=advs,
 | 
			
		||||
            policyflats_old=policyflats
 | 
			
		||||
        )
 | 
			
		||||
        if cliptype == ClipType.adaptivekl:
 | 
			
		||||
            kwargs_in_scalar.update(
 | 
			
		||||
                kl_coef=kl_coef
 | 
			
		||||
            )
 | 
			
		||||
        elif cliptype in [ClipType.kl2clip, ClipType.kl2clip_rollback]:
 | 
			
		||||
            pas = np.exp(-neglogpacs)
 | 
			
		||||
 | 
			
		||||
            if isinstance(env.action_space, gym.spaces.box.Box):
 | 
			
		||||
                results_kl2clip = kl2clip(
 | 
			
		||||
                    mu0_logstd0=policyflats, a=actions, pas=pas,
 | 
			
		||||
                    delta=args.clipargs.klrange,
 | 
			
		||||
                    adjusttype=args.clipargs.adjusttype, cliprange=args.clipargs.cliprange,
 | 
			
		||||
                    require_sol = False,
 | 
			
		||||
                    verbose = 1
 | 
			
		||||
                    # sharelogstd=args.clipargs, clip_clipratio=args.kl2clip_clip_clipratio,
 | 
			
		||||
                )
 | 
			
		||||
                cliprange_upper = results_kl2clip.ratio.max
 | 
			
		||||
                cliprange_lower = results_kl2clip.ratio.min
 | 
			
		||||
 | 
			
		||||
                if args.clipargs.adaptive_range == '2constant':
 | 
			
		||||
                    cliprange_upper = cliprange_upper - (cliprange_upper - cliprange_upper_min) * frac
 | 
			
		||||
                    cliprange_lower = cliprange_lower + (cliprange_lower_max - cliprange_lower) * frac
 | 
			
		||||
                    # TODO: debug tmp
 | 
			
		||||
                    debugs['cliprange_upper_min'] = cliprange_upper_min
 | 
			
		||||
                    debugs['cliprange_lower_max'] = cliprange_lower_max
 | 
			
		||||
                elif args.clipargs.adaptive_range == '2cliprange_final':
 | 
			
		||||
                    cliprange_final = args.clipargs.cliprange_final
 | 
			
		||||
                    cliprange_upper = cliprange_upper - (cliprange_upper - (1 + cliprange_final)) * frac
 | 
			
		||||
                    cliprange_lower = cliprange_lower + ((1 - cliprange_final) - cliprange_lower) * frac
 | 
			
		||||
                elif args.clipargs.adaptive_range == '2constant_upper':
 | 
			
		||||
                    cliprange_upper = cliprange_upper - (cliprange_upper - cliprange_upper_min) * frac
 | 
			
		||||
                    # TODO: debug tmp
 | 
			
		||||
                    debugs['cliprange_upper_min'] = cliprange_upper_min
 | 
			
		||||
                    debugs['cliprange_lower_max'] = cliprange_lower_max
 | 
			
		||||
                # elif args.clipargs.adaptive_range == '2cliprange':
 | 
			
		||||
                #     # TODO: cliprange may be none....
 | 
			
		||||
                #     cliprange_upper = cliprange_upper - (cliprange_upper- (1+args.clipargs.cliprange))*frac
 | 
			
		||||
                #     cliprange_lower = cliprange_lower + ((1-args.clipargs.cliprange) - cliprange_lower )*frac
 | 
			
		||||
                elif args.clipargs.adaptive_range == 'clip2cliprange':
 | 
			
		||||
                    frac_threshold = args.clipargs.frac_threshold
 | 
			
		||||
                    if frac >= frac_threshold:
 | 
			
		||||
                        cliprange_upper[:] = 1 + args.clipargs.cliprange
 | 
			
		||||
                        cliprange_lower[:] = 1 - args.clipargs.cliprange
 | 
			
		||||
            elif isinstance(env.action_space, gym.spaces.discrete.Discrete):
 | 
			
		||||
                results_kl2clip = kl2clip(
 | 
			
		||||
                    pas=pas,
 | 
			
		||||
                    delta=args.clipargs.klrange,
 | 
			
		||||
                    verbose = 1
 | 
			
		||||
                )
 | 
			
		||||
                cliprange_upper = results_kl2clip.ratio.max
 | 
			
		||||
                cliprange_lower = results_kl2clip.ratio.min
 | 
			
		||||
 | 
			
		||||
                cliprange_upper = 1 + (cliprange_upper - 1) * frac_remain
 | 
			
		||||
                cliprange_lower = 1 - (1. - cliprange_lower) * frac_remain
 | 
			
		||||
 | 
			
		||||
            # if isinstance(env.action_space, gym.spaces.Discrete):
 | 
			
		||||
            #     raise NotImplemented('Please review the code')
 | 
			
		||||
            #     cliprange_upper = 1 + (cliprange_upper - 1) * frac_remain
 | 
			
		||||
            #     cliprange_lower = 1 - (1. - cliprange_lower) * frac_remain
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            debugs['cliprange_upper'] = cliprange_upper
 | 
			
		||||
            debugs['cliprange_lower'] = cliprange_lower
 | 
			
		||||
            kwargs_in_arr.update(
 | 
			
		||||
                cliprange_upper=cliprange_upper,
 | 
			
		||||
                cliprange_lower=cliprange_lower,
 | 
			
		||||
            )
 | 
			
		||||
        elif cliptype == ClipType.adaptiverange_advantage:
 | 
			
		||||
            cliprange_max = args.clipargs.cliprange_max
 | 
			
		||||
            # positive
 | 
			
		||||
            advs_positive = advs[advs>0]
 | 
			
		||||
            adv_upper = np.median(advs_positive) * 2#use median, avoide the affect of overlarge values
 | 
			
		||||
            cliprange_upper = np.minimum( np.abs(advs), adv_upper)
 | 
			
		||||
            cliprange_upper = 1 + cliprange_upper / cliprange_upper.max() * cliprange_max
 | 
			
		||||
 | 
			
		||||
            advs_negative = advs[advs<0]
 | 
			
		||||
            adv_lower = np.median(advs_negative) *2
 | 
			
		||||
            cliprange_lower = np.maximum( -np.abs(advs), adv_lower )
 | 
			
		||||
            cliprange_lower = 1 - cliprange_lower/cliprange_lower.min() * cliprange_max
 | 
			
		||||
 | 
			
		||||
            debugs['cliprange_upper'] = cliprange_upper
 | 
			
		||||
            debugs['cliprange_lower'] = cliprange_lower
 | 
			
		||||
            kwargs_in_arr.update(
 | 
			
		||||
                cliprange_upper=cliprange_upper,
 | 
			
		||||
                cliprange_lower=cliprange_lower,
 | 
			
		||||
            )
 | 
			
		||||
        elif cliptype in [ClipType.kl, ClipType.kl_ratiorollback, ClipType.kl_klrollback_constant, ClipType.kl_klrollback, ClipType.kl_strict, ClipType.kl_klrollback_constant_withratio]:
 | 
			
		||||
        if cliptype in [ClipType.kl, ClipType.kl_ratiorollback, ClipType.kl_klrollback_constant, ClipType.kl_klrollback, ClipType.kl_strict, ClipType.kl_klrollback_constant_withratio]:
 | 
			
		||||
            klrange = args.clipargs.klrange
 | 
			
		||||
            if 'decay_threshold' in args.clipargs.keys():
 | 
			
		||||
                decay_threshold = args.clipargs.decay_threshold
 | 
			
		||||
@ -607,16 +588,6 @@ def learn(*, policy, env, env_eval, n_steps, total_timesteps, ent_coef, lr,
 | 
			
		||||
                    coef_ = frac_remain/(1-decay_threshold)
 | 
			
		||||
                    klrange *= coef_
 | 
			
		||||
            kwargs_in_scalar.update( klrange = klrange  )
 | 
			
		||||
        elif cliptype in [ClipType.wasserstein, ClipType.wasserstein_rollback_constant, ClipType.totalvariation, ClipType.totalvariation_rollback_constant ]:
 | 
			
		||||
            range_ = args.clipargs.range
 | 
			
		||||
            if 'decay_threshold' in args.clipargs.keys():
 | 
			
		||||
                decay_threshold = args.clipargs.decay_threshold
 | 
			
		||||
                if frac >= decay_threshold:
 | 
			
		||||
                    coef_ = frac_remain/(1-decay_threshold)
 | 
			
		||||
                    range_ *= coef_
 | 
			
		||||
            kwargs_in_scalar.update( range = range_  )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        # print(kwargs_in_scalar)
 | 
			
		||||
        # ----------------- Train the model
 | 
			
		||||
        mblossvals = []
 | 
			
		||||
 | 
			
		||||
@ -140,7 +140,10 @@ def arg_parser_common():
 | 
			
		||||
 | 
			
		||||
                    a2c=dict(clipargs=dict(cliprange=0.1)),
 | 
			
		||||
 | 
			
		||||
                    kl=dict(clipargs=dict(klrange=0.001, cliprange=0.1, decay_threshold=0.)),
 | 
			
		||||
                    kl=dict(
 | 
			
		||||
                        clipargs=dict(klrange=0.001, cliprange=0.1, decay_threshold=0.),
 | 
			
		||||
                        coef_entropy = 0
 | 
			
		||||
                    ),
 | 
			
		||||
                    kl_ratiorollback=dict(clipargs=dict(klrange=0.001,slope_rollback=-0.05, cliprange=0.1, decay_threshold=0.)),
 | 
			
		||||
                    kl_klrollback_constant=dict(clipargs=dict(klrange=0.001, slope_rollback=-0.05, cliprange=0.1, decay_threshold=0.)),
 | 
			
		||||
                    kl_klrollback_constant_withratio= dict(
 | 
			
		||||
@ -153,7 +156,8 @@ def arg_parser_common():
 | 
			
		||||
                        clipargs=dict(range=0.02, slope_rollback=-0.05, cliprange=0.1, decay_threshold=0.)
 | 
			
		||||
                    ),
 | 
			
		||||
                    kl2clip=dict(
 | 
			
		||||
                        clipargs=dict(klrange=0.001, cliprange=0.1, kl2clip_opttype='tabular', adaptive_range='')
 | 
			
		||||
                        clipargs=dict(klrange=0.001, cliprange=0.1, kl2clip_opttype='tabular', adaptive_range=''),
 | 
			
		||||
                        coef_entropy=0
 | 
			
		||||
                    ),
 | 
			
		||||
                    adaptivekl=dict(
 | 
			
		||||
                        clipargs=dict(klrange=0.01, cliprange=0.1)
 | 
			
		||||
@ -236,7 +240,7 @@ def main():
 | 
			
		||||
            args.env_full = f'{args.env}-v4'
 | 
			
		||||
    tools.warn_(f'Run with setting for {args.envtype} task!!!!!')
 | 
			
		||||
 | 
			
		||||
    assert bool(args.alg) != bool(args.cliptype), 'Only one arg can be specified'
 | 
			
		||||
    assert bool(args.alg) != bool(args.cliptype), 'Either alg or cliptype should be specified'
 | 
			
		||||
    if args.alg: # For release
 | 
			
		||||
        args.cliptype = alg2cliptype[args.alg]
 | 
			
		||||
        keys_exclude.append('cliptype')
 | 
			
		||||
@ -297,7 +301,7 @@ def main():
 | 
			
		||||
    # TODO prepare_dir: change .finish_indicator to finishi_indictator, which is more clear.
 | 
			
		||||
    # --- prepare dir
 | 
			
		||||
    import baselines
 | 
			
		||||
    root_dir = tools_logger.get_logger_dir(  'baselines', baselines, 'results' )
 | 
			
		||||
    root_dir = tools_logger.get_logger_dir(  'baselines', 'results', baselines )
 | 
			
		||||
    args = tools_logger.prepare_dirs( args, key_first='env', keys_exclude=keys_exclude, dirs_type=['log' ], root_dir=root_dir )
 | 
			
		||||
    # --- prepare args for use
 | 
			
		||||
    args.cliptype = ClipType[ args.cliptype ]
 | 
			
		||||
@ -384,4 +388,4 @@ def main():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										2
									
								
								toolsm
									
									
									
									
									
								
							
							
								
								
								
								
								
								
									
									
								
							
						
						
									
										2
									
								
								toolsm
									
									
									
									
									
								
							@ -1 +1 @@
 | 
			
		||||
Subproject commit 88306bfb6c09c73b07b8cfba2ef5db1dc686066c
 | 
			
		||||
Subproject commit 5221b220cf12ef6aac64d8910aab23df35a91df6
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user