TensorFlow DQNAgent collect_policy InvalidArgumentError 解決方案:TimeStepSpec 與 TimeStep 張量形狀匹配指南

TensorFlow DQNAgent collect_policy InvalidArgumentError 解決方案:TimeStepSpec 與 TimeStep 張量形狀匹配指南

在 tf_agents 框架中構建強化學習代理時,開發者可能會遇到一個常見的 InvalidArgumentError,尤其是在調用 DqnAgent 的 collect_policy.action() 方法時。這個錯誤通常表現為 {{function_node __wrapped__select_device_/job:localhost/replica:0/task:0/device:CPU:0}} ‘then’ and ‘else’ must have the same size. but received: [1] vs. [] [Op:Select] name:。奇怪的是,代理的 policy.action() 方法可能運行正常,這使得問題更加難以定位。本文將深入剖析此錯誤的原因,并提供詳細的解決方案和最佳實踐。

核心概念:TimeStepSpec 與 TimeStep 張量

理解這個錯誤的關鍵在于區分 tf_agents 中 timestepspec 的作用和實際 timestep 張量的結構。

  1. TimeStepSpec (時間步規范)

    • TimeStepSpec 定義了環境中每個時間步的數據結構和預期的數據類型與形狀。
    • 關鍵點:TensorSpec 在 TimeStepSpec 中定義的 shape 不包含批處理維度。它描述的是單個樣本的形狀。例如,如果獎勵是一個標量,其 TensorSpec 的 shape 應該是 ()。如果觀測是一個形狀為 (6,) 的向量,其 TensorSpec 的 shape 應該是 (6,)。
  2. TimeStep 張量 (實際時間步數據)

    • 實際傳遞給代理策略的 TimeStep 對象包含具體的 tensorflow 張量數據。
    • 關鍵點:這些張量必須包含批處理維度。即使 batch_size 為 1,也需要一個顯式的批處理維度。例如,一個標量獎勵在 batch_size=1 時,其張量形狀應為 (1,)。一個形狀為 (6,) 的觀測向量在 batch_size=1 時,其張量形狀應為 (1, 6)。

問題根源:batch_size=1 時的形狀誤解

InvalidArgumentError: ‘then’ and ‘else’ must have the same size. but received: [1] vs. [] 這個錯誤通常發生在 tf_agents 內部的 tf.where 操作中,尤其是在 EpsilonGreedyPolicy(collect_policy 通常是這種策略)的探索邏輯里。tf.where(condition, x, y) 函數要求 x 和 y 具有兼容的形狀。當 TimeStepSpec 和實際 TimeStep 張量的形狀定義出現不一致時,就會導致這種錯誤。

最常見的誤區是,當某個時間步組件(如 step_type, reward, discount)本質上是標量時,在 TimeStepSpec 中將其 shape 定義為 (1,) 而非 ()。

  • 錯誤定義示例 (在 TimeStepSpec 中)

    reward = tensor_spec.TensorSpec(shape=(1,), dtype=tf.float32) # 錯誤:期望單樣本是一個1維向量

    這里,shape=(1,) 意味著每個樣本是一個包含一個元素的1維向量。然而,reward 通常是一個標量。當 collect_policy 內部處理這些標量值時,如果它期望一個真正的標量(即形狀為 ()),但實際從 TimeStep 張量中得到的是一個形狀為 (1,) 的張量,或者在與 TensorSpec 匹配時發生隱式形狀轉換,就可能導致 [1] vs [] 的不匹配。

  • 實際張量 (當 batch_size=1 時)

    reward_tensor = tf.convert_to_tensor([reward_value], dtype=tf.float32) # 形狀為 (1,)

    這個 (1,) 形狀是正確的,它表示一個批次中包含一個標量。問題在于 TimeStepSpec 沒有正確地將它視為一個批次中的標量。

解決方案:正確配置 TimeStepSpec 與 TimeStep 張量

解決此問題的關鍵在于確保 TimeStepSpec 中的 shape 定義反映單個樣本的真實形狀,并且實際 TimeStep 張量在構造時包含正確的批處理維度。

1. 對于標量數據 (step_type, reward, discount)

這些組件在每個時間步通常都是單個數值(標量)。

  • TimeStepSpec 中的正確定義: 將 shape=(1,) 更正為 shape=(),表示一個標量。

    from tf_agents.specs import tensor_spec import tensorflow as tf from tf_agents.trajectories.time_step import TimeStep  # ... 其他導入 ...  # 修正后的 TimeStepSpec 定義 time_step_spec = TimeStep(     step_type=tensor_spec.BoundedTensorSpec(shape=(), dtype=tf.int32, minimum=0, maximum=2),     reward=tensor_spec.TensorSpec(shape=(), dtype=tf.float32),     discount=tensor_spec.TensorSpec(shape=(), dtype=tf.float32),     # 觀測值根據其本身的維度定義,見下文     observation=tensor_spec.TensorSpec(shape=(amountMachines,), dtype=tf.int32) )
  • 實際 TimeStep 張量的正確構造 (針對 batch_size=1): 使用 tf.convert_to_tensor([value], dtype=…) 來創建一個形狀為 (1,) 的張量,表示一個批次中包含一個標量。

    # 假設 step_type_value, reward_value, discount_value 是python標量 time_step = TimeStep(     step_type=tf.convert_to_tensor([step_type_value], dtype=tf.int32), # 形狀 (1,)     reward=tf.convert_to_tensor([reward_value], dtype=tf.float32),     # 形狀 (1,)     discount=tf.convert_to_tensor([discount_value], dtype=tf.float32), # 形狀 (1,)     observation=current_state_batch # 觀測值構造見下文 )

2. 對于多維觀測數據 (observation)

觀測值通常是向量或更高維的張量。

  • TimeStepSpec 中的正確定義: shape 應反映單個觀測樣本的維度。例如,如果 current_state 是一個形狀為 (amountMachines,) 的 numpy 數組,那么 TensorSpec 的 shape 應該是 (amountMachines,)。
     # ... 在 time_step_spec 定義中 ... observation=

? 版權聲明
THE END
喜歡就支持一下吧
點贊11 分享