Computer Science/Python

[Tensorflow] 텐서플로우 자료형

Taewon Heo 2017. 6. 3. 19:26

[Tensorflow] 텐서플로우 자료형


텐서플로우에서는 일반적으로 파이썬에서 하듯이 x = 3 이렇게 변수를 정의하면 안된다.


텐서플로우의 자료형


상수형(Constant)


import tensorflow as tf


a = tf.constant([3], dtype=tf.float32)

b = tf.constant([5], dtype=tf.float32)


이런 식으로 tf.~로 정의하는데 짜증나는건 dtype도 float32가 아니라 tf.float32이다.

a의 타입을 찍어보면

print type(a)

>> <class 'tensorflow.python.framework.ops.Tensor'>


연산가능한 타입이 아니므로, a+b 와 같은 일반적인 방법으로는 연산이 안된다는 것을 알 수 있다.


그래프와 세션

그러면 a + b 의 값을 얻고 싶으면 어떻게 해야할까?
c = a + b라고 할 때, a + b를 그래프(Graph)라고 한다.

실제로 값을 뽑아내려면, 세션(Session)에 그래프를 넣어서 실행해야 한다.

session = tf.Session()
result = session.run(b)

print result
>> [5.]
드디어 숫자가 눈에 보인다. 하지만 왜 배열 형태일까?
맨 처음에 a = tf.constant([3], ...)이렇게 정의했기 때문이다.
만약에 tf.constant([1,2,3],~)의 배열의 개수가 1이 아니엇다면 배열의 개수만큼 결과값이 나온다.
연산하려는 행렬의 개수가 안맞으면 에러가 발생한다.


플레이스홀더

그렇다면 상수 대신에 변수를 이용하여 연산을 하려면 어떻게 해야할까?
여기서 Placeholeder라는 것을 사용한다.
플레이스홀더는 Variable을 담는 그릇이고 여기에는 학습 데이터를 피딩(하나씩 넣음)한다.
tf.placeholder(shape, dtype, name) 여기엔 shape, dtype, name의 속성이 필요하다.

import tensorflow as tf

x = tf.placeholder("float", None)
y = x * 2

with tf.Session() as session:
    result = session.run(y, feed_dict={x: [1, 2, 3]})
    print(result)

http://learningtensorflow.com/lesson4/에서 참조한 예제이다.

x는 float type의 placeholder이고 y = x * 2라는 그래프가 있다.

tf.Session()을 session이라고 정의하여 result = session.run(Graph, feed_dict) 로 실행한다.
상수 연산만 할 때(feed_dict가 없을 때) 였다면 run(y)만 했을 것이다. 

여기서는 x라는 플레이스홀더에 데이터를 넣어줘야 하기 때문에 run(y, 넣어줄 데이터)를 한다.
with문 대신 이렇게 써도 상관없다.
session = tf.Session()
result = session.run(y, feed_dict={x: [1, 2, 3]})
print(result)  

변수형

y = W * x라는 그래프를 만들어보자.
import tensorflow as tf

input_data = [1,2,3,4]
x = tf.placeholder("float", None)
W = tf.Variable([2],dtype=tf.float32)
y = W * x

session = tf.Session()
init = tf.initialize_all_variables()
session.run(init)
result = session.run(y, feed_dict={x:input_data})
print(result)

여기서 주의해야할 점은 Variable은 변수 초기화를 해줘야 [2]라는 값이 W에 저장된다는 것이다.
저 코드 없이 실행하면 에러가 발생한다.

텐서플로우는 일반적인 파이썬 자료형과 다르기 때문에
기본 개념을 알고 넘어가야 머신러닝이 수월해질 것 같아서 정리했다.